ai_agents_core/traits/
llm.rs1use async_trait::async_trait;
4use serde::{Deserialize, Serialize};
5use std::collections::HashMap;
6use thiserror::Error;
7
8use crate::message::ChatMessage;
9use crate::types::{LLMChunk, LLMConfig, LLMFeature, LLMResponse};
10
11#[async_trait]
17pub trait LLMProvider: Send + Sync {
18 async fn complete(
20 &self,
21 messages: &[ChatMessage],
22 config: Option<&LLMConfig>,
23 ) -> Result<LLMResponse, LLMError>;
24
25 async fn complete_stream(
27 &self,
28 messages: &[ChatMessage],
29 config: Option<&LLMConfig>,
30 ) -> Result<Box<dyn futures::Stream<Item = Result<LLMChunk, LLMError>> + Unpin + Send>, LLMError>;
31
32 fn provider_name(&self) -> &str;
34
35 fn supports(&self, feature: LLMFeature) -> bool;
37}
38
39#[async_trait]
41pub trait LLMCapability: Send + Sync {
42 async fn select_tool(
43 &self,
44 context: &TaskContext,
45 user_input: &str,
46 ) -> Result<ToolSelection, LLMError>;
47
48 async fn generate_tool_args(
49 &self,
50 tool_id: &str,
51 user_input: &str,
52 schema: &serde_json::Value,
53 ) -> Result<serde_json::Value, LLMError>;
54
55 async fn evaluate_yesno(
56 &self,
57 question: &str,
58 context: &TaskContext,
59 ) -> Result<(bool, String), LLMError>;
60
61 async fn classify(&self, input: &str, categories: &[String])
62 -> Result<(String, f32), LLMError>;
63
64 async fn process_task(
65 &self,
66 context: &TaskContext,
67 system_prompt: &str,
68 ) -> Result<LLMResponse, LLMError>;
69}
70
71#[derive(Debug, Clone, Serialize, Deserialize)]
73pub struct TaskContext {
74 pub current_state: Option<String>,
75 pub available_tools: Vec<String>,
76 pub memory_slots: HashMap<String, serde_json::Value>,
77 pub recent_messages: Vec<ChatMessage>,
78}
79
80#[derive(Debug, Clone, Serialize, Deserialize)]
82pub struct ToolSelection {
83 pub tool_id: String,
84 pub confidence: f32,
85 pub reasoning: Option<String>,
86}
87
88#[derive(Debug, Error)]
90pub enum LLMError {
91 #[error("API error: {message}")]
92 API {
93 message: String,
94 status: Option<u16>,
95 },
96
97 #[error("Network error: {0}")]
98 Network(String),
99
100 #[error("Rate limit exceeded: {retry_after:?}")]
101 RateLimit {
102 retry_after: Option<std::time::Duration>,
103 },
104
105 #[error("Configuration error: {0}")]
106 Config(String),
107
108 #[error("Model not found: {0}")]
109 ModelNotFound(String),
110
111 #[error("Content filtered: {0}")]
112 ContentFiltered(String),
113
114 #[error("Serialization error: {0}")]
115 Serialization(String),
116
117 #[error("Other error: {0}")]
118 Other(String),
119}
120
121impl From<serde_json::Error> for LLMError {
122 fn from(err: serde_json::Error) -> Self {
123 LLMError::Serialization(err.to_string())
124 }
125}