neomemx 0.1.2

A high-performance memory library for AI agents with semantic search.

//! Base trait for LLM implementations

use std::collections::HashMap;

use async_trait::async_trait;
use serde::{Deserialize, Serialize};

use crate::error::Result;

/// A message in a conversation
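///
/// A construction sketch using the helpers defined below (fenced `ignore`
/// because the doctest import path is not assumed here):
///
/// ```ignore
/// let history = vec![
///     Message::system("You are a terse assistant."),
///     Message::user("Explain `Vec::retain` in one sentence.").with_name("alice"),
/// ];
/// assert_eq!(history[0].role, "system");
/// assert_eq!(history[1].name.as_deref(), Some("alice"));
/// ```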
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Message {
    /// The role of the message sender (system, user, assistant)
    pub role: String,
    /// The content of the message
    pub content: String,
    /// Optional name for the sender
    #[serde(skip_serializing_if = "Option::is_none")]
    pub name: Option<String>,
}

impl Message {
    /// Create a new system message
    pub fn system(content: impl Into<String>) -> Self {
        Self {
            role: "system".to_string(),
            content: content.into(),
            name: None,
        }
    }

    /// Create a new user message
    pub fn user(content: impl Into<String>) -> Self {
        Self {
            role: "user".to_string(),
            content: content.into(),
            name: None,
        }
    }

    /// Create a new assistant message
    pub fn assistant(content: impl Into<String>) -> Self {
        Self {
            role: "assistant".to_string(),
            content: content.into(),
            name: None,
        }
    }

    /// Set the name for this message
    pub fn with_name(mut self, name: impl Into<String>) -> Self {
        self.name = Some(name.into());
        self
    }
}

/// Response format specification
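///
/// The field serializes under the wire key `type` via `#[serde(rename)]`; a
/// serialization sketch (assumes `serde_json` is available):
///
/// ```ignore
/// let fmt = ResponseFormat::json_object();
/// assert_eq!(serde_json::to_string(&fmt).unwrap(), r#"{"type":"json_object"}"#);
/// ```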
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct ResponseFormat {
    /// The type of response format ("text" or "json_object")
    #[serde(rename = "type")]
    pub format_type: String,
}

impl ResponseFormat {
    /// Create a text response format
    pub fn text() -> Self {
        Self {
            format_type: "text".to_string(),
        }
    }

    /// Create a JSON object response format
    pub fn json_object() -> Self {
        Self {
            format_type: "json_object".to_string(),
        }
    }
}

/// Tool definition for function calling
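///
/// A construction sketch for a hypothetical weather-lookup function; the
/// `parameters` value follows the JSON Schema shape commonly used for
/// OpenAI-style function calling:
///
/// ```ignore
/// let tool = Tool {
///     tool_type: "function".to_string(),
///     function: FunctionDef {
///         name: "get_weather".to_string(),
///         description: "Look up the current weather for a city".to_string(),
///         parameters: serde_json::json!({
///             "type": "object",
///             "properties": { "city": { "type": "string" } },
///             "required": ["city"]
///         }),
///     },
/// };
/// ```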
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Tool {
    /// The type of tool (currently always "function")
    #[serde(rename = "type")]
    pub tool_type: String,
    /// The function definition
    pub function: FunctionDef,
}

/// Function definition for tools
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FunctionDef {
    /// Name of the function
    pub name: String,
    /// Description of the function
    pub description: String,
    /// Parameters schema
    pub parameters: serde_json::Value,
}

/// A tool call made by the LLM
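///
/// Extracting a string argument from a call (sketch; `call` is an assumed
/// binding of this type):
///
/// ```ignore
/// if let Some(serde_json::Value::String(city)) = call.arguments.get("city") {
///     println!("requested city: {city}");
/// }
/// ```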
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolCall {
    /// Name of the function called
    pub name: String,
    /// Arguments passed to the function
    pub arguments: HashMap<String, serde_json::Value>,
}

/// Response from an LLM
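///
/// A handling sketch covering both variants (`response` is an assumed binding
/// of this type):
///
/// ```ignore
/// match response {
///     LlmResponse::Text(text) => println!("{text}"),
///     LlmResponse::WithToolCalls { content, tool_calls } => {
///         if let Some(text) = content {
///             println!("{text}");
///         }
///         for call in tool_calls {
///             println!("model requested tool: {}", call.name);
///         }
///     }
/// }
/// ```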
#[derive(Debug, Clone)]
pub enum LlmResponse {
    /// A text response
    Text(String),
    /// A response with tool calls
    WithToolCalls {
        /// Optional text content alongside the tool calls
        content: Option<String>,
        /// The tool calls requested by the model
        tool_calls: Vec<ToolCall>,
    },
}

impl LlmResponse {
    /// Get the text content of the response, if any
    pub fn text(&self) -> Option<&str> {
        match self {
            LlmResponse::Text(s) => Some(s),
            LlmResponse::WithToolCalls { content, .. } => content.as_deref(),
        }
    }

    /// Get the text content, returning an empty string if none
    pub fn text_or_empty(&self) -> &str {
        self.text().unwrap_or("")
    }
}

/// Base trait for LLM implementations
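///
/// A minimal stub implementation sketch (`EchoLlm` is hypothetical; it ignores
/// formats and tools, which can be useful in tests):
///
/// ```ignore
/// struct EchoLlm;
///
/// #[async_trait]
/// impl LlmBase for EchoLlm {
///     async fn generate_response(
///         &self,
///         messages: Vec<Message>,
///         _response_format: Option<ResponseFormat>,
///         _tools: Option<Vec<Tool>>,
///         _tool_choice: Option<String>,
///     ) -> Result<LlmResponse> {
///         // Echo the content of the last message back as the "model" reply.
///         let last = messages.last().map(|m| m.content.clone()).unwrap_or_default();
///         Ok(LlmResponse::Text(last))
///     }
/// }
/// ```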
#[async_trait]
pub trait LlmBase: Send + Sync {
    /// Generate a response from the LLM
    ///
    /// # Arguments
    /// * `messages` - The conversation history
    /// * `response_format` - Optional response format specification
    /// * `tools` - Optional list of tools available to the LLM
    /// * `tool_choice` - How to choose tools ("auto", "none", or a specific tool name)
    async fn generate_response(
        &self,
        messages: Vec<Message>,
        response_format: Option<ResponseFormat>,
        tools: Option<Vec<Tool>>,
        tool_choice: Option<String>,
    ) -> Result<LlmResponse>;

    /// Generate a simple text response
    async fn generate(&self, messages: Vec<Message>) -> Result<String> {
        let response = self.generate_response(messages, None, None, None).await?;
        Ok(response.text_or_empty().to_string())
    }

    /// Generate a JSON response
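    ///
    /// A usage sketch (`llm` is an assumed binding to any `LlmBase` implementor):
    ///
    /// ```ignore
    /// let raw = llm.generate_json(vec![
    ///     Message::system("Reply with a single JSON object."),
    ///     Message::user("ping"),
    /// ]).await?;
    /// let value: serde_json::Value = serde_json::from_str(&raw).expect("valid JSON");
    /// ```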
    async fn generate_json(&self, messages: Vec<Message>) -> Result<String> {
        let response = self
            .generate_response(messages, Some(ResponseFormat::json_object()), None, None)
            .await?;
        Ok(response.text_or_empty().to_string())
    }
}