menta 0.0.5

Minimal Rust library for non-UI LLM and AI primitives
Documentation
use crate::{
    EmbeddingModel, EmbeddingResult, FinishReason, LanguageModel, ModelMessage, ModelRequest,
    ModelResponse, Part, ProviderRegistration, Result, ToolCall, ToolChoice, ToolDefinition,
    ToolResult, Usage,
};

/// Deterministic, offline stand-in for a chat model, registered as the
/// `"mock"` provider.
#[derive(Debug, Clone)]
pub struct MockLanguageModel {
    /// Identifier returned by `model_id` and echoed in generated text replies.
    model_id: String,
}

impl MockLanguageModel {
    /// Returns the text of the most recent user message whose text is non-empty.
    ///
    /// User messages whose `text()` is empty are skipped. Returns `None` when no
    /// such message exists.
    fn last_user_text(&self, messages: &[ModelMessage]) -> Option<String> {
        messages
            .iter()
            .rev()
            .filter(|message| message.role == crate::Role::User)
            .map(|message| message.text())
            .find(|text| !text.is_empty())
    }

    /// Returns the most recent tool result that has not been superseded by a
    /// newer user turn.
    ///
    /// Messages are scanned newest-first. Within a message the last
    /// `Part::ToolResult` wins. If a user message with non-empty text is reached
    /// before any tool result, the function returns `None`: that user turn is
    /// newer than every remaining tool result, so none of them should be
    /// reported again.
    fn last_tool_result<'a>(&self, messages: &'a [ModelMessage]) -> Option<&'a ToolResult> {
        for message in messages.iter().rev() {
            let newest_result = message.parts.iter().rev().find_map(|part| match part {
                Part::ToolResult(result) => Some(result),
                _ => None,
            });
            if newest_result.is_some() {
                return newest_result;
            }
            // Same notion of a "user turn" as `last_user_text`: user messages
            // without text do not end the search.
            if message.role == crate::Role::User && !message.text().is_empty() {
                return None;
            }
        }
        None
    }
}

impl LanguageModel for MockLanguageModel {
    fn model_id(&self) -> &str {
        &self.model_id
    }

    /// Produces a deterministic, offline response for `request`.
    ///
    /// Rules, checked in order:
    ///
    /// 1. If `last_tool_result` returns a tool result, reply with the text
    ///    `"I checked <name> and found: <output>"`.
    /// 2. If `should_call_tool` selects a tool for the trimmed last user text,
    ///    emit a single [`ToolCall`] to it. The call input is the user text with
    ///    a leading `"remember "` or `"search "` removed and then trimmed.
    /// 3. Otherwise reply with text. If the user text contains "json" (ASCII
    ///    case-insensitive), the reply is a fixed JSON object. Otherwise it is
    ///    `"Mock response from <model_id>: <user text>"`.
    ///
    /// Every response carries `metadata_with_provider("mock", model_id)`.
    /// Usage is estimated from the request messages plus the emitted text; for
    /// a tool call, the user text is used.
    fn generate<'a>(&'a self, request: &'a ModelRequest) -> crate::ModelFuture<'a, ModelResponse> {
        Box::pin(async move {
            if let Some(result) = self.last_tool_result(&request.messages) {
                let text = format!("I checked {} and found: {}", result.name, result.output);
                // Estimate usage from the text actually returned, as the plain-text
                // branch below does, rather than from an ad-hoc "name output" string.
                let usage = crate::estimate_usage(&request.messages, &text);
                return Ok(ModelResponse {
                    parts: vec![Part::Text(text)],
                    finish_reason: FinishReason::Stop,
                    usage,
                    response_metadata: crate::metadata_with_provider("mock", &self.model_id),
                });
            }

            let user_text = self.last_user_text(&request.messages).unwrap_or_default();
            let user_text = user_text.trim();

            if let Some(tool_name) =
                should_call_tool(user_text, &request.tools, &request.tool_choice)
            {
                // Drop the command word that `should_call_tool` routes on, so the
                // tool receives only its argument.
                let input = user_text
                    .strip_prefix("remember ")
                    .or_else(|| user_text.strip_prefix("search "))
                    .unwrap_or(user_text)
                    .trim()
                    .to_string();

                return Ok(ModelResponse {
                    parts: vec![Part::ToolCall(ToolCall::new(tool_name, input))],
                    finish_reason: FinishReason::ToolCalls,
                    usage: crate::estimate_usage(&request.messages, user_text),
                    response_metadata: crate::metadata_with_provider("mock", &self.model_id),
                });
            }

            let text = if user_text.to_ascii_lowercase().contains("json") {
                "{\"topic\":\"llm\",\"status\":\"ok\"}".to_string()
            } else {
                format!("Mock response from {}: {}", self.model_id, user_text)
            };
            // Compute usage before `text` is moved into the response, so no clone is needed.
            let usage = crate::estimate_usage(&request.messages, &text);

            Ok(ModelResponse {
                parts: vec![Part::Text(text)],
                finish_reason: FinishReason::Stop,
                usage,
                response_metadata: crate::metadata_with_provider("mock", &self.model_id),
            })
        })
    }
}

/// Decides which tool, if any, the mock model should call for `user_text`.
///
/// - `ToolChoice::None` never selects a tool.
/// - `ToolChoice::Required(name)` selects `name` if a tool with that name is in `tools`.
/// - `ToolChoice::Auto` routes the prefix `"remember "` to `add_resource` and
///   `"search "` to `get_information`, if that tool is in `tools`.
///
/// Returns the selected tool's name, or `None` when no tool applies.
fn should_call_tool(
    user_text: &str,
    tools: &[ToolDefinition],
    tool_choice: &ToolChoice,
) -> Option<String> {
    // Command prefixes understood in `Auto` mode and the tool each one targets.
    const PREFIX_ROUTES: [(&str, &str); 2] =
        [("remember ", "add_resource"), ("search ", "get_information")];

    let wanted: &str = match tool_choice {
        ToolChoice::None => return None,
        ToolChoice::Required(name) => name.as_str(),
        ToolChoice::Auto => {
            let (_, routed) = PREFIX_ROUTES
                .iter()
                .find(|(prefix, _)| user_text.starts_with(*prefix))?;
            *routed
        }
    };

    tools
        .iter()
        .find(|tool| tool.name == wanted)
        .map(|tool| tool.name.clone())
}

/// Deterministic, offline embedding model registered as the `"mock"` provider.
#[derive(Debug, Clone)]
pub struct MockEmbeddingModel {
    /// Identifier returned by `model_id`.
    model_id: String,
}

impl EmbeddingModel for MockEmbeddingModel {
    fn model_id(&self) -> &str {
        &self.model_id
    }

    /// Embeds `value` as an 8-element letter histogram.
    ///
    /// Each character is lowercased. Every resulting ASCII letter adds `1.0` to
    /// bucket `(letter - 'a') % 8`; all other characters are ignored. Input
    /// usage is `crate::count_tokens(value)` and output usage is always zero.
    fn embed(&self, value: &str) -> Result<EmbeddingResult> {
        const DIMENSIONS: usize = 8;
        let mut histogram = vec![0.0; DIMENSIONS];

        value
            .chars()
            .flat_map(char::to_lowercase)
            .filter(char::is_ascii_alphabetic)
            .map(|letter| usize::from(letter as u8 - b'a') % DIMENSIONS)
            .for_each(|slot| histogram[slot] += 1.0);

        let usage = Usage {
            input_tokens: crate::count_tokens(value),
            output_tokens: 0,
        };
        Ok(EmbeddingResult {
            embedding: histogram,
            usage,
        })
    }
}

/// Provider constructor for the mock language model; always succeeds.
fn mock_language_model(model_id: &str) -> Result<Box<dyn LanguageModel>> {
    let model: Box<dyn LanguageModel> = Box::new(MockLanguageModel {
        model_id: model_id.to_owned(),
    });
    Ok(model)
}

/// Provider constructor for the mock embedding model; always succeeds.
fn mock_embedding_model(model_id: &str) -> Result<Box<dyn EmbeddingModel>> {
    let model: Box<dyn EmbeddingModel> = Box::new(MockEmbeddingModel {
        model_id: model_id.to_owned(),
    });
    Ok(model)
}

// Registers the mock provider with `inventory` under the id "mock", pairing
// it with the `mock_language_model` and `mock_embedding_model` constructors.
inventory::submit! {
    ProviderRegistration {
        embedding_model: mock_embedding_model,
        language_model: mock_language_model,
        id: "mock",
    }
}