1use crate::types::*;
2use serde::{Deserialize, Serialize};
3use std::future::Future;
4use std::{collections::HashMap, pin::Pin};
5mod anthropic;
6mod deepseek;
7mod google;
8mod openai;
9
10pub use anthropic::{Anthropic, AnthropicModel};
11pub use deepseek::{DeepSeek, DeepSeekModel};
12pub use google::{GoogleAI, GoogleModel};
13pub use openai::{OpenAI, OpenAIModel};
14
15#[derive(Debug, Clone, Serialize, Deserialize)]
16pub struct LLMResult {
17 pub text: String,
18 pub metadata: Option<HashMap<String, serde_json::Value>>,
19}
20
21#[derive(Debug, Clone)]
22pub struct LLMOptions {
23 pub temperature: Option<f32>,
24 pub max_tokens: Option<u32>,
25 pub top_p: Option<f32>,
26 pub top_k: Option<u32>,
27 pub frequency_penalty: Option<f32>,
28 pub presence_penalty: Option<f32>,
29 pub stop_sequences: Option<Vec<String>>,
30 pub seed: Option<u64>,
31 pub response_format: Option<ResponseFormat>,
32}
33
34#[derive(Debug, Clone)]
35pub enum ResponseFormat {
36 Text,
37 Json,
38 JsonSchema { schema: serde_json::Value },
39}
40
41impl Default for LLMOptions {
42 fn default() -> Self {
43 Self {
44 temperature: Some(0.7),
45 max_tokens: Some(4096),
46 top_p: Some(0.95),
47 top_k: None,
48 frequency_penalty: None,
49 presence_penalty: None,
50 stop_sequences: None,
51 seed: None,
52 response_format: None,
53 }
54 }
55}
56
57pub trait LLM: Send + Sync {
58 fn generate(&self, prompt: &str) -> Pin<Box<dyn Future<Output = Result<LLMResult>> + Send + '_>>;
59
60 fn generate_with_options(
61 &self,
62 prompt: &str,
63 options: LLMOptions,
64 ) -> Pin<Box<dyn Future<Output = Result<LLMResult>> + Send + '_>>;
65
66 fn generate_batch(
67 &self,
68 prompts: Vec<String>,
69 ) -> Pin<Box<dyn Future<Output = Result<Vec<LLMResult>>> + Send + '_>> {
70 Box::pin(async move {
71 let mut results = Vec::new();
72 for prompt in prompts {
73 results.push(self.generate(&prompt).await?);
74 }
75 Ok(results)
76 })
77 }
78
79 fn chat(
80 &self,
81 messages: Vec<ChatMessage>,
82 ) -> Pin<Box<dyn Future<Output = Result<LLMResult>> + Send + '_>> {
83 Box::pin(async move {
84 let prompt = messages
85 .iter()
86 .map(|m| format!("{}: {}", m.role, m.content))
87 .collect::<Vec<_>>()
88 .join("\n");
89 self.generate(&prompt).await
90 })
91 }
92
93 fn get_model_name(&self) -> &str;
94 fn get_provider_name(&self) -> &str;
95 fn supports_function_calling(&self) -> bool {
96 false
97 }
98 fn supports_json_mode(&self) -> bool {
99 false
100 }
101 fn max_context_length(&self) -> Option<usize> {
102 None
103 }
104
105 fn get_provider_enum(&self) -> ModelProvider;
106}
107
108#[derive(Debug, Clone, Serialize, Deserialize)]
109pub struct ChatMessage {
110 pub role: String,
111 pub content: String,
112 pub name: Option<String>,
113 pub tool_calls: Option<Vec<ToolCall>>,
114}
115
116#[derive(Debug, Clone, Serialize, Deserialize)]
117pub struct ToolCall {
118 pub id: String,
119 pub r#type: String,
120 pub function: FunctionCall,
121}
122
123#[derive(Debug, Clone, Serialize, Deserialize)]
124pub struct FunctionCall {
125 pub name: String,
126 pub arguments: String,
127}
128
129impl ChatMessage {
130 pub fn user(content: &str) -> Self {
131 Self {
132 role: "user".to_string(),
133 content: content.to_string(),
134 name: None,
135 tool_calls: None,
136 }
137 }
138
139 pub fn assistant(content: &str) -> Self {
140 Self {
141 role: "assistant".to_string(),
142 content: content.to_string(),
143 name: None,
144 tool_calls: None,
145 }
146 }
147
148 pub fn system(content: &str) -> Self {
149 Self {
150 role: "system".to_string(),
151 content: content.to_string(),
152 name: None,
153 tool_calls: None,
154 }
155 }
156}