emmylua_code_analysis 0.22.0

A library for analyzing lua code.
Documentation
use serde::de::{self, Visitor};
use serde::{Deserialize, Deserializer, Serialize, Serializer};
use std::fmt;
use std::{collections::HashMap, sync::Arc};

use emmylua_parser::{LuaAstNode, LuaClosureExpr, LuaDocFuncType};
use rowan::TextSize;

use super::return_rows;
use crate::db_index::signature::async_state::AsyncState;
use crate::{
    FileId,
    db_index::{LuaFunctionType, LuaType},
};
use crate::{LuaAttributeUse, SemanticModel, first_param_may_not_self};

#[derive(Debug)]
pub struct LuaSignature {
    pub generic_params: Vec<Arc<LuaGenericParamInfo>>,
    pub overloads: Vec<Arc<LuaFunctionType>>,
    pub param_docs: HashMap<usize, LuaDocParamInfo>,
    pub params: Vec<String>,
    pub return_docs: Vec<LuaDocReturnInfo>,
    pub return_overloads: Vec<LuaDocReturnOverloadInfo>,
    pub resolve_return: SignatureReturnStatus,
    pub is_colon_define: bool,
    pub async_state: AsyncState,
    pub nodiscard: Option<LuaNoDiscard>,
    pub is_vararg: bool,
}

#[derive(Debug, Clone, PartialEq, Eq)]
pub enum LuaNoDiscard {
    NoDiscard,
    NoDiscardWithMessage(Box<String>),
}

impl Default for LuaSignature {
    fn default() -> Self {
        Self::new()
    }
}

impl LuaSignature {
    pub fn new() -> Self {
        Self {
            generic_params: Vec::new(),
            overloads: Vec::new(),
            param_docs: HashMap::new(),
            params: Vec::new(),
            return_docs: Vec::new(),
            return_overloads: Vec::new(),
            resolve_return: SignatureReturnStatus::UnResolve,
            is_colon_define: false,
            async_state: AsyncState::None,
            nodiscard: None,
            is_vararg: false,
        }
    }

    pub fn is_generic(&self) -> bool {
        !self.generic_params.is_empty()
    }

    pub fn is_resolve_return(&self) -> bool {
        self.resolve_return != SignatureReturnStatus::UnResolve
    }

    pub fn get_type_params(&self) -> Vec<(String, Option<LuaType>)> {
        let mut type_params = Vec::new();
        for (idx, param_name) in self.params.iter().enumerate() {
            if let Some(param_info) = self.param_docs.get(&idx) {
                type_params.push((param_name.clone(), Some(param_info.type_ref.clone())));
            } else {
                type_params.push((param_name.clone(), None));
            }
        }

        type_params
    }

    pub fn find_param_idx(&self, param_name: &str) -> Option<usize> {
        self.params.iter().position(|name| name == param_name)
    }

    pub fn get_param_info_by_name(&self, param_name: &str) -> Option<&LuaDocParamInfo> {
        // fast enough
        let idx = self.params.iter().position(|name| name == param_name)?;
        self.param_docs.get(&idx)
    }

    pub fn get_param_name_by_id(&self, idx: usize) -> Option<String> {
        if idx < self.params.len() {
            return Some(self.params[idx].clone());
        } else if let Some(name) = self.params.last()
            && name == "..."
        {
            return Some(name.clone());
        }

        None
    }

    pub fn get_param_info_by_id(&self, idx: usize) -> Option<&LuaDocParamInfo> {
        if idx < self.params.len() {
            return self.param_docs.get(&idx);
        } else if let Some(name) = self.params.last()
            && name == "..."
        {
            return self.param_docs.get(&(self.params.len() - 1));
        }

        None
    }

    pub fn get_return_type(&self) -> LuaType {
        return_rows::get_return_type(&self.return_docs, &self.return_overloads)
    }

    pub(crate) fn get_overload_row_slot(row: &[LuaType], idx: usize) -> LuaType {
        return_rows::get_overload_row_slot(row, idx)
    }

    pub(crate) fn row_to_return_type(row: Vec<LuaType>) -> LuaType {
        return_rows::row_to_return_type(row)
    }

    pub(crate) fn return_type_to_row(return_type: LuaType) -> Vec<LuaType> {
        return_rows::return_type_to_row(return_type)
    }

    pub fn is_method(&self, semantic_model: &SemanticModel, owner_type: Option<&LuaType>) -> bool {
        if self.is_colon_define {
            return true;
        }

        if let Some(param_info) = self.get_param_info_by_id(0) {
            let param_type = &param_info.type_ref;
            if param_type.is_self_infer() {
                return true;
            }
            match owner_type {
                Some(owner_type) => {
                    // 一些类型不应该被视为 method
                    if matches!(owner_type, LuaType::Ref(_) | LuaType::Def(_))
                        && first_param_may_not_self(param_type)
                    {
                        return false;
                    }

                    semantic_model
                        .type_check(owner_type, &param_info.type_ref)
                        .is_ok()
                }
                None => param_info.name == "self",
            }
        } else {
            false
        }
    }

    pub fn to_doc_func_type(&self) -> Arc<LuaFunctionType> {
        let params = self.get_type_params();
        let return_type = self.get_return_type();
        let is_vararg = self.is_vararg;
        let func_type = LuaFunctionType::new(
            self.async_state,
            self.is_colon_define,
            is_vararg,
            params,
            return_type,
        );
        Arc::new(func_type)
    }

    pub fn to_call_operator_func_type(&self) -> Arc<LuaFunctionType> {
        let mut params = self.get_type_params();
        if !params.is_empty() && !self.is_colon_define {
            params.remove(0);
        }

        let return_type = self.get_return_type();
        let func_type =
            LuaFunctionType::new(self.async_state, false, self.is_vararg, params, return_type);
        Arc::new(func_type)
    }
}

#[derive(Debug)]
pub struct LuaDocParamInfo {
    pub name: String,
    pub type_ref: LuaType,
    pub nullable: bool,
    pub description: Option<String>,
    pub attributes: Option<Vec<LuaAttributeUse>>,
}

impl LuaDocParamInfo {
    pub fn get_attribute_by_name(&self, name: &str) -> Option<&LuaAttributeUse> {
        self.attributes
            .iter()
            .flatten()
            .find(|attr| attr.id.get_name() == name)
    }
}

#[derive(Debug, Clone)]
pub struct LuaDocReturnInfo {
    pub name: Option<String>,
    pub type_ref: LuaType,
    pub description: Option<String>,
    pub attributes: Option<Vec<LuaAttributeUse>>,
}

#[derive(Debug, Clone)]
pub struct LuaDocReturnOverloadInfo {
    pub type_refs: Vec<LuaType>,
    pub description: Option<String>,
}

#[derive(Debug, Hash, Eq, PartialEq, Clone, Copy)]
pub struct LuaSignatureId {
    file_id: FileId,
    position: TextSize,
}

impl Serialize for LuaSignatureId {
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        let value = format!("{}|{}", self.file_id.id, u32::from(self.position));
        serializer.serialize_str(&value)
    }
}

impl<'de> Deserialize<'de> for LuaSignatureId {
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: Deserializer<'de>,
    {
        struct LuaSignatureIdVisitor;

        impl<'de> Visitor<'de> for LuaSignatureIdVisitor {
            type Value = LuaSignatureId;

            fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
                formatter.write_str("a string with format 'file_id:position'")
            }

            fn visit_str<E>(self, value: &str) -> Result<Self::Value, E>
            where
                E: de::Error,
            {
                let parts: Vec<&str> = value.split('|').collect();
                if parts.len() != 2 {
                    return Err(E::custom("expected format 'file_id:position'"));
                }

                let file_id = FileId {
                    id: parts[0]
                        .parse()
                        .map_err(|e| E::custom(format!("invalid file_id: {}", e)))?,
                };
                let position = TextSize::new(
                    parts[1]
                        .parse()
                        .map_err(|e| E::custom(format!("invalid position: {}", e)))?,
                );

                Ok(LuaSignatureId { file_id, position })
            }
        }

        deserializer.deserialize_str(LuaSignatureIdVisitor)
    }
}

impl LuaSignatureId {
    pub fn from_closure(file_id: FileId, closure: &LuaClosureExpr) -> Self {
        Self {
            file_id,
            position: closure.get_position(),
        }
    }

    pub fn from_doc_func(file_id: FileId, func_type: &LuaDocFuncType) -> Self {
        Self {
            file_id,
            position: func_type.get_position(),
        }
    }

    pub fn get_file_id(&self) -> FileId {
        self.file_id
    }

    pub fn get_position(&self) -> TextSize {
        self.position
    }
}

#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum SignatureReturnStatus {
    UnResolve,
    DocResolve,
    InferResolve,
}

#[derive(Debug, Clone)]
pub struct LuaGenericParamInfo {
    pub name: String,
    pub constraint: Option<LuaType>,
    pub attributes: Option<Vec<LuaAttributeUse>>,
}

impl LuaGenericParamInfo {
    pub fn new(
        name: String,
        constraint: Option<LuaType>,
        attributes: Option<Vec<LuaAttributeUse>>,
    ) -> Self {
        Self {
            name,
            constraint,
            attributes,
        }
    }
}