wat_service 0.1.0

WebAssembly Text Format language service.
Documentation
use crate::{
    binder::{SymbolItemKind, SymbolTable},
    data_set::{self, OperandType},
    files::FilesCtx,
    helpers,
    types_analyzer::TypesAnalyzerCtx,
    InternUri, LanguageService,
};
use line_index::LineIndex;
use lsp_types::{Diagnostic, DiagnosticRelatedInformation, DiagnosticSeverity, Location};
use rowan::{
    ast::{
        support::{children, token},
        AstNode,
    },
    TextRange,
};
use wat_syntax::{
    ast::{Instr, PlainInstr},
    SyntaxKind, SyntaxNode,
};

/// Validates the operands of a folded plain instruction, e.g. `(i32.add (a) (b))`.
///
/// Nodes that are not plain instructions, have no instruction name, or are not
/// parenthesized are ignored. Leading immediate operands are skipped: exactly one for
/// `call`, otherwise `operands_count` from the instruction metadata (0 when unknown).
/// Every remaining operand should be an instruction; the result types of those operand
/// instructions are compared against the parameter types from `resolve_expected_types`.
///
/// Diagnostics pushed onto `diags`, in this order:
/// - "expected instr" for each non-instruction operand after the immediates (only when
///   metadata for the instruction exists);
/// - an operand-count error when expected and received counts differ;
/// - one type-mismatch error per value-typed operand whose type differs from the
///   expected one, with related information pointing at the expectation's origin when
///   known.
pub fn check_folded(
    diags: &mut Vec<Diagnostic>,
    service: &LanguageService,
    uri: InternUri,
    line_index: &LineIndex,
    node: SyntaxNode,
    symbol_table: &SymbolTable,
) {
    let Some(instr) = PlainInstr::cast(node) else {
        return;
    };
    let Some(instr_name) = instr.instr_name() else {
        return;
    };
    // Only folded (parenthesized) instructions are handled here.
    if instr.l_paren_token().is_none() {
        return;
    }
    let meta = data_set::INSTR_METAS.get(instr_name.text());
    // Immediates are not value operands; `call` always carries exactly one (the callee).
    let skipped_count = match instr_name.text() {
        "call" => 1,
        _ => meta.map(|meta| meta.operands_count).unwrap_or_default(),
    };

    // Each received type is paired with the operand that produced it, for diagnostics.
    let mut received_types = vec![];
    for operand in instr.operands().skip(skipped_count) {
        match operand.instr() {
            Some(operand_instr) => {
                if let Some(types) = resolve_type(service, uri, symbol_table, &operand_instr) {
                    received_types.extend(types.into_iter().map(|ty| (ty, operand.clone())));
                }
            }
            None if meta.is_some() => diags.push(Diagnostic {
                range: helpers::rowan_range_to_lsp_range(
                    line_index,
                    operand.syntax().text_range(),
                ),
                severity: Some(DiagnosticSeverity::ERROR),
                source: Some("wat".into()),
                message: "expected instr".into(),
                ..Default::default()
            }),
            None => {}
        }
    }

    let Some(params) = resolve_expected_types(service, uri, symbol_table, &instr, meta) else {
        return;
    };

    // Non-instruction operands (immediates included) count towards the received total.
    let non_instr_operands = instr
        .operands()
        .filter(|operand| operand.instr().is_none())
        .count();
    let expected_count = skipped_count + params.len();
    let received_count = non_instr_operands + received_types.len();
    if received_count != expected_count {
        diags.push(Diagnostic {
            range: helpers::rowan_range_to_lsp_range(line_index, instr.syntax().text_range()),
            severity: Some(DiagnosticSeverity::ERROR),
            source: Some("wat".into()),
            message: build_incorrect_operands_count_msg(expected_count, received_count),
            ..Default::default()
        });
    }

    for ((expected, related), (received, operand)) in params.iter().zip(received_types.iter()) {
        // Only concrete value types are compared; other operand kinds are not checked here.
        let (OperandType::Val(expected), OperandType::Val(received)) = (expected, received) else {
            continue;
        };
        if expected == received {
            continue;
        }
        diags.push(Diagnostic {
            range: helpers::rowan_range_to_lsp_range(line_index, operand.syntax().text_range()),
            severity: Some(DiagnosticSeverity::ERROR),
            source: Some("wat".into()),
            message: format!("expected type `{expected}`, found `{received}`"),
            related_information: related.as_ref().map(|(range, message)| {
                vec![DiagnosticRelatedInformation {
                    location: Location {
                        uri: service.lookup_uri(uri),
                        range: helpers::rowan_range_to_lsp_range(line_index, *range),
                    },
                    message: message.clone(),
                }]
            }),
            ..Default::default()
        });
    }
}

/// Validates operands of stack-form (non-folded) instructions among `node`'s children.
///
/// If every instruction child of `node` (plain, `block`, `if`, `loop`) is folded, the
/// function returns immediately; folded operands are checked by `check_folded`.
/// Otherwise it walks the `Instr` children in order while simulating an operand stack
/// of `(type, producing instruction)` entries:
/// - non-folded plain instructions pop as many entries as they have parameters; an
///   operand-count error is reported if the stack is too short, and a type-mismatch
///   error for each popped value type that differs from the expected parameter type;
/// - folded plain instructions pop nothing (their operands are their own children);
/// - every plain instruction then pushes its resolved result types, if known.
///
/// Block instructions are skipped entirely: they neither pop nor push.
pub fn check_stacked(
    diags: &mut Vec<Diagnostic>,
    service: &LanguageService,
    uri: InternUri,
    line_index: &LineIndex,
    node: &SyntaxNode,
    symbol_table: &SymbolTable,
) {
    let all_folded = node
        .children()
        .filter(|child| {
            matches!(
                child.kind(),
                SyntaxKind::PLAIN_INSTR
                    | SyntaxKind::BLOCK_BLOCK
                    | SyntaxKind::BLOCK_IF
                    | SyntaxKind::BLOCK_LOOP
            )
        })
        .all(|child| token(&child, SyntaxKind::L_PAREN).is_some());
    if all_folded {
        return;
    }

    // Simulated operand stack; the last element is the top of the stack.
    let mut types_stack = Vec::<(_, Instr)>::with_capacity(2);
    for instr in children::<Instr>(node) {
        let Instr::Plain(plain_instr) = &instr else {
            continue;
        };
        let Some(instr_name) = plain_instr.instr_name() else {
            continue;
        };
        let meta = data_set::INSTR_METAS.get(instr_name.text());
        let Some(params) = resolve_expected_types(service, uri, symbol_table, plain_instr, meta)
        else {
            continue;
        };
        // Folded instructions take their operands from their own children.
        let expected_count = if plain_instr.l_paren_token().is_some() {
            0
        } else {
            params.len()
        };
        let available = types_stack.len();
        // `pop_start`: index of the first stack entry consumed by this instruction.
        // `missing`: number of leading parameters with no value on the stack. The top of
        // the stack supplies the *last* parameter, so on underflow the available values
        // line up with the trailing parameters, not the leading ones.
        let (pop_start, missing) = match available.checked_sub(expected_count) {
            Some(pop_start) => (pop_start, 0),
            None => {
                diags.push(Diagnostic {
                    range: helpers::rowan_range_to_lsp_range(
                        line_index,
                        instr.syntax().text_range(),
                    ),
                    severity: Some(DiagnosticSeverity::ERROR),
                    source: Some("wat".into()),
                    message: build_incorrect_operands_count_msg(expected_count, available),
                    ..Default::default()
                });
                (0, expected_count - available)
            }
        };
        let type_mismatches = params
            .iter()
            .skip(missing)
            .zip(types_stack.drain(pop_start..))
            .filter_map(|pair| {
                if let (
                    (OperandType::Val(expected), related),
                    (OperandType::Val(received), related_instr),
                ) = pair
                {
                    if expected == &received {
                        None
                    } else {
                        Some(Diagnostic {
                            range: helpers::rowan_range_to_lsp_range(
                                line_index,
                                related_instr.syntax().text_range(),
                            ),
                            severity: Some(DiagnosticSeverity::ERROR),
                            source: Some("wat".into()),
                            message: format!("expected type `{expected}`, found `{received}`"),
                            related_information: related.as_ref().map(|(range, message)| {
                                vec![DiagnosticRelatedInformation {
                                    location: Location {
                                        uri: service.lookup_uri(uri),
                                        range: helpers::rowan_range_to_lsp_range(
                                            line_index, *range,
                                        ),
                                    },
                                    message: message.clone(),
                                }]
                            }),
                            ..Default::default()
                        })
                    }
                } else {
                    None
                }
            });
        diags.extend(type_mismatches);

        if let Some(types) = resolve_type(service, uri, symbol_table, &instr) {
            types_stack.extend(types.into_iter().map(|ty| (ty, instr.clone())));
        }
    }
}

/// Infers the types an instruction leaves on the operand stack.
///
/// Block instructions are not analysed and yield `None`. For plain instructions:
/// - `call` yields the result types of the resolved callee's signature;
/// - `local.get` yields the type of the referenced parameter or local;
/// - `global.get` yields the type of the referenced global;
/// - any other instruction yields the `results` listed in `data_set::INSTR_METAS`.
///
/// Returns `None` whenever the instruction name, its index operand, the referenced
/// definition, or that definition's type cannot be resolved.
fn resolve_type(
    service: &LanguageService,
    uri: InternUri,
    symbol_table: &SymbolTable,
    instr: &Instr,
) -> Option<Vec<OperandType>> {
    let Instr::Plain(plain_instr) = instr else {
        return None;
    };
    let instr_name = plain_instr.instr_name()?;
    let name = instr_name.text();
    match name {
        "call" => {
            let callee = plain_instr.operands().next()?;
            let func = symbol_table
                .find_defs(&callee.syntax().clone().into())
                .into_iter()
                .flatten()
                .next()?;
            let sig = service.get_func_sig(uri, func.clone().into())?;
            Some(sig.results.iter().cloned().map(OperandType::Val).collect())
        }
        "local.get" => {
            let idx = plain_instr.operands().next()?;
            let local = symbol_table.find_param_or_local_def(&idx.syntax().clone().into())?;
            let ty = service.extract_type(local.green.clone())?;
            Some(vec![OperandType::Val(ty)])
        }
        "global.get" => {
            let idx = plain_instr.operands().next()?;
            let global = symbol_table
                .find_defs(&idx.syntax().clone().into())
                .into_iter()
                .flatten()
                .next()?;
            let ty = service.extract_global_type(global.green.clone())?;
            Some(vec![OperandType::Val(ty)])
        }
        _ => data_set::INSTR_METAS
            .get(name)
            .map(|meta| meta.results.clone()),
    }
}

/// An expected operand type, optionally paired with the text range and message of the
/// declaration the expectation comes from (used as diagnostic related information).
type ExpectedType = (OperandType, Option<(TextRange, String)>);

/// Resolves the operand types `instr` expects, in operand order (first operand first).
///
/// For `call`, the callee (first operand) is resolved through `symbol_table`. Its
/// signature's parameters are returned, each paired with the range of the matching
/// declared `Param` symbol when one exists. For any other instruction, the `params` of
/// `meta` are returned without related information.
///
/// Returns `None` if the instruction has no name, the callee cannot be resolved, its
/// signature is unavailable, or (for non-`call` instructions) `meta` is `None`.
fn resolve_expected_types(
    service: &LanguageService,
    uri: InternUri,
    symbol_table: &SymbolTable,
    instr: &PlainInstr,
    meta: Option<&data_set::InstrMeta>,
) -> Option<Vec<ExpectedType>> {
    if instr.instr_name()?.text() == "call" {
        let idx = instr.operands().next()?;
        let func = symbol_table
            .find_defs(&idx.syntax().clone().into())
            .into_iter()
            .flatten()
            .next()?;
        let root = SyntaxNode::new_root(service.root(uri));
        let related = symbol_table
            .get_declared_params_and_locals(func.key.ptr.to_node(&root))
            .filter(|symbol| symbol.kind == SymbolItemKind::Param)
            .map(|symbol| {
                Some((
                    symbol.key.ptr.text_range(),
                    "parameter originally defined here".into(),
                ))
            });
        service.get_func_sig(uri, func.clone().into()).map(|sig| {
            sig.params
                .iter()
                .map(|ty| OperandType::Val(ty.0.clone()))
                // Pad with `None` so every signature parameter is kept. A bare `zip` would
                // truncate to the number of declared `Param` symbols, which can be fewer
                // than the signature's parameters (NOTE(review): e.g. several types in one
                // `(param ...)`, or a signature from a type use — confirm against binder).
                // That would silently drop expected operands.
                .zip(related.chain(std::iter::repeat(None)))
                .collect()
        })
    } else {
        meta.map(|meta| {
            meta.params
                .iter()
                .map(|param| (param.clone(), None))
                .collect()
        })
    }
}

/// Formats the operand-count mismatch message, e.g. `expected 2 operands, found 1`.
///
/// The noun is singular only when exactly one operand is expected.
fn build_incorrect_operands_count_msg(expected: usize, received: usize) -> String {
    let noun = match expected {
        1 => "operand",
        _ => "operands",
    };
    format!("expected {expected} {noun}, found {received}")
}