Skip to main content

ai_agents_core/traits/
llm.rs

1//! LLM provider traits
2
3use async_trait::async_trait;
4use serde::{Deserialize, Serialize};
5use std::collections::HashMap;
6use thiserror::Error;
7
8use crate::message::ChatMessage;
9use crate::types::{LLMChunk, LLMConfig, LLMFeature, LLMResponse};
10
/// Core LLM provider trait.
///
/// Implement this to integrate a custom LLM backend. Most users can use
/// `UnifiedLLMProvider` which supports OpenAI, Anthropic, and other providers
/// out of the box.
///
/// Implementations must be `Send + Sync` so a provider can be shared across
/// async tasks (e.g. behind an `Arc<dyn LLMProvider>`).
#[async_trait]
pub trait LLMProvider: Send + Sync {
    /// Send messages and get a complete response.
    ///
    /// `messages` is the full conversation history to send; `config` overrides
    /// provider defaults when `Some` (model, temperature, etc. — see
    /// [`LLMConfig`]).
    ///
    /// # Errors
    ///
    /// Returns an [`LLMError`] on API, network, rate-limit, or configuration
    /// failures.
    async fn complete(
        &self,
        messages: &[ChatMessage],
        config: Option<&LLMConfig>,
    ) -> Result<LLMResponse, LLMError>;

    /// Send messages and get a streaming response.
    ///
    /// On success, yields a boxed [`futures::Stream`] of [`LLMChunk`]s; each
    /// item may itself be an `Err` if the stream fails mid-response.
    ///
    /// # Errors
    ///
    /// Returns an [`LLMError`] if the request cannot be initiated; per-chunk
    /// errors are reported through the stream items instead.
    async fn complete_stream(
        &self,
        messages: &[ChatMessage],
        config: Option<&LLMConfig>,
    ) -> Result<Box<dyn futures::Stream<Item = Result<LLMChunk, LLMError>> + Unpin + Send>, LLMError>;

    /// Provider identifier (e.g. `"openai"`, `"anthropic"`).
    fn provider_name(&self) -> &str;

    /// Check if this provider supports a given feature.
    ///
    /// Callers can use this to gate optional capabilities (streaming, tools,
    /// etc. — whatever [`LLMFeature`] enumerates) before relying on them.
    fn supports(&self, feature: LLMFeature) -> bool;
}
38
/// Higher-level LLM capabilities for agent operations.
///
/// Whereas [`LLMProvider`] is a thin transport to a model, this trait exposes
/// task-oriented operations (tool selection, argument generation,
/// classification) that agents call directly.
#[async_trait]
pub trait LLMCapability: Send + Sync {
    /// Choose the most appropriate tool for `user_input`, given the task
    /// context (available tools, memory, recent messages).
    ///
    /// # Errors
    ///
    /// Returns an [`LLMError`] if the underlying model call fails or the
    /// selection cannot be produced.
    async fn select_tool(
        &self,
        context: &TaskContext,
        user_input: &str,
    ) -> Result<ToolSelection, LLMError>;

    /// Generate arguments for the tool identified by `tool_id`.
    ///
    /// `schema` is a JSON value describing the expected argument shape
    /// (presumably a JSON Schema — TODO confirm against implementations);
    /// the returned value should conform to it.
    async fn generate_tool_args(
        &self,
        tool_id: &str,
        user_input: &str,
        schema: &serde_json::Value,
    ) -> Result<serde_json::Value, LLMError>;

    /// Answer a yes/no `question` about the task.
    ///
    /// Returns the boolean verdict together with a string — presumably the
    /// model's reasoning/justification (verify against implementations).
    async fn evaluate_yesno(
        &self,
        question: &str,
        context: &TaskContext,
    ) -> Result<(bool, String), LLMError>;

    /// Classify `input` into one of `categories`.
    ///
    /// Returns the chosen category and an `f32` score — presumably a
    /// confidence in the unit interval, mirroring
    /// [`ToolSelection::confidence`] (confirm with implementations).
    async fn classify(&self, input: &str, categories: &[String])
    -> Result<(String, f32), LLMError>;

    /// Run a full task turn: combine `system_prompt` with the task context
    /// and return the model's complete response.
    async fn process_task(
        &self,
        context: &TaskContext,
        system_prompt: &str,
    ) -> Result<LLMResponse, LLMError>;
}
70
/// Task context for LLM operations.
///
/// Bundles the state an [`LLMCapability`] implementation needs to make
/// decisions; serializable so it can be logged or persisted.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TaskContext {
    /// Current state identifier, if the agent is in a known state
    /// (semantics defined by the agent's state machine — not visible here).
    pub current_state: Option<String>,
    /// Tool identifiers the agent may select from (matches
    /// [`ToolSelection::tool_id`]).
    pub available_tools: Vec<String>,
    /// Arbitrary keyed memory values carried across turns; values are
    /// free-form JSON.
    pub memory_slots: HashMap<String, serde_json::Value>,
    /// Recent conversation history supplied to the model for context.
    pub recent_messages: Vec<ChatMessage>,
}
79
/// Tool selection result returned by [`LLMCapability::select_tool`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolSelection {
    /// Identifier of the chosen tool (drawn from
    /// [`TaskContext::available_tools`]).
    pub tool_id: String,
    /// Selection confidence — presumably in `[0.0, 1.0]`; range is not
    /// enforced here, confirm with producers.
    pub confidence: f32,
    /// Optional model-provided explanation of why this tool was chosen.
    pub reasoning: Option<String>,
}
87
/// LLM error types.
///
/// `Display` messages come from the `thiserror` `#[error(...)]` attributes;
/// match on variants for programmatic handling (e.g. retry on `RateLimit`).
#[derive(Debug, Error)]
pub enum LLMError {
    /// The provider's API returned an error response.
    #[error("API error: {message}")]
    API {
        /// Provider-supplied error message.
        message: String,
        /// HTTP status code, when one was received.
        status: Option<u16>,
    },

    /// Transport-level failure (DNS, TLS, connection, timeout, ...).
    #[error("Network error: {0}")]
    Network(String),

    /// The provider rejected the request due to rate limiting.
    #[error("Rate limit exceeded: {retry_after:?}")]
    RateLimit {
        /// Suggested back-off duration, if the provider supplied one.
        retry_after: Option<std::time::Duration>,
    },

    /// Invalid or missing configuration (API key, endpoint, ...).
    #[error("Configuration error: {0}")]
    Config(String),

    /// The requested model identifier is unknown to the provider.
    #[error("Model not found: {0}")]
    ModelNotFound(String),

    /// The provider's content filter blocked the request or response.
    #[error("Content filtered: {0}")]
    ContentFiltered(String),

    /// JSON (de)serialization failed; see the `From<serde_json::Error>` impl.
    #[error("Serialization error: {0}")]
    Serialization(String),

    /// Catch-all for errors that fit no other variant.
    #[error("Other error: {0}")]
    Other(String),
}
120
121impl From<serde_json::Error> for LLMError {
122    fn from(err: serde_json::Error) -> Self {
123        LLMError::Serialization(err.to_string())
124    }
125}