1use async_trait::async_trait;
4use serde::{Deserialize, Serialize};
5
6use crate::Result;
7
8#[async_trait]
12pub trait LlmProvider: Send + Sync {
13 async fn complete(&self, request: CompletionRequest) -> Result<CompletionResponse>;
17
18 async fn complete_streaming(&self, request: CompletionRequest) -> Result<CompletionStream>;
22}
23
24#[derive(Debug, Clone, Serialize, Deserialize)]
26pub struct CompletionRequest {
27 pub system_prompt: Option<String>,
29
30 pub messages: Vec<Message>,
32
33 pub max_tokens: u32,
35
36 pub temperature: Option<f32>,
38
39 pub stop_sequences: Vec<String>,
41}
42
43impl CompletionRequest {
44 pub fn new(messages: Vec<Message>) -> Self {
46 Self {
47 system_prompt: None,
48 messages,
49 max_tokens: 1024,
50 temperature: None,
51 stop_sequences: Vec::new(),
52 }
53 }
54
55 pub fn with_system_prompt(mut self, prompt: impl Into<String>) -> Self {
57 self.system_prompt = Some(prompt.into());
58 self
59 }
60
61 pub fn with_max_tokens(mut self, max_tokens: u32) -> Self {
63 self.max_tokens = max_tokens;
64 self
65 }
66
67 pub fn with_temperature(mut self, temperature: f32) -> Self {
69 self.temperature = Some(temperature);
70 self
71 }
72
73 pub fn with_stop_sequence(mut self, sequence: impl Into<String>) -> Self {
75 self.stop_sequences.push(sequence.into());
76 self
77 }
78}
79
80#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
82pub struct Message {
83 pub role: Role,
85
86 pub content: String,
88}
89
90impl Message {
91 pub fn user(content: impl Into<String>) -> Self {
93 Self {
94 role: Role::User,
95 content: content.into(),
96 }
97 }
98
99 pub fn assistant(content: impl Into<String>) -> Self {
101 Self {
102 role: Role::Assistant,
103 content: content.into(),
104 }
105 }
106}
107
108#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
110#[serde(rename_all = "lowercase")]
111pub enum Role {
112 User,
114 Assistant,
116}
117
118#[derive(Debug, Clone, Serialize, Deserialize)]
120pub struct CompletionResponse {
121 pub content: String,
123
124 pub tokens_used: TokenUsage,
126
127 pub stop_reason: StopReason,
129}
130
131#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
133pub struct TokenUsage {
134 pub input: u64,
136
137 pub output: u64,
139}
140
141impl TokenUsage {
142 pub fn total(&self) -> u64 {
144 self.input + self.output
145 }
146}
147
148#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
150#[serde(rename_all = "snake_case")]
151#[non_exhaustive]
152pub enum StopReason {
153 EndTurn,
155
156 MaxTokens,
158
159 StopSequence,
161}
162
/// Handle returned by [`LlmProvider::complete_streaming`].
///
/// Currently an opaque placeholder that carries no data. The private field
/// prevents construction outside this module. It derives `Debug` so that
/// `Result<CompletionStream>` values and structs embedding it can be
/// debug-printed, like every other public type in this module.
#[derive(Debug)]
pub struct CompletionStream {
    _private: (),
}
170
#[cfg(test)]
// unwrap is acceptable here: a panic is exactly the test-failure signal.
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    #[test]
    fn test_message_constructors() {
        let user_msg = Message::user("Hello");
        assert_eq!(user_msg.role, Role::User);
        assert_eq!(user_msg.content, "Hello");

        let asst_msg = Message::assistant("Hi there");
        assert_eq!(asst_msg.role, Role::Assistant);
        assert_eq!(asst_msg.content, "Hi there");
    }

    #[test]
    fn test_completion_request_defaults() {
        let request = CompletionRequest::new(vec![Message::user("Test")]);
        assert_eq!(request.system_prompt, None);
        assert_eq!(request.messages, vec![Message::user("Test")]);
        assert_eq!(request.max_tokens, 1024);
        assert_eq!(request.temperature, None);
        assert!(request.stop_sequences.is_empty());
    }

    #[test]
    fn test_completion_request_builder() {
        let request = CompletionRequest::new(vec![Message::user("Test")])
            .with_system_prompt("You are helpful")
            .with_max_tokens(2048)
            .with_temperature(0.7)
            .with_stop_sequence("\n\n");

        assert_eq!(request.system_prompt, Some("You are helpful".to_string()));
        assert_eq!(request.max_tokens, 2048);
        assert_eq!(request.temperature, Some(0.7));
        assert_eq!(request.stop_sequences, vec!["\n\n"]);
    }

    #[test]
    fn test_stop_sequences_accumulate() {
        let request = CompletionRequest::new(Vec::new())
            .with_stop_sequence("a")
            .with_stop_sequence("b");
        assert_eq!(request.stop_sequences, vec!["a", "b"]);
    }

    #[test]
    fn test_token_usage_total() {
        let usage = TokenUsage {
            input: 100,
            output: 200,
        };
        assert_eq!(usage.total(), 300);
    }

    #[test]
    fn test_message_serialization() {
        let msg = Message::user("test content");
        let json = serde_json::to_string(&msg).unwrap();
        let deserialized: Message = serde_json::from_str(&json).unwrap();
        assert_eq!(msg, deserialized);
    }

    #[test]
    fn test_message_wire_format() {
        let json = serde_json::to_string(&Message::assistant("hi")).unwrap();
        assert_eq!(json, r#"{"role":"assistant","content":"hi"}"#);
    }

    #[test]
    fn test_role_serializes_lowercase() {
        assert_eq!(serde_json::to_string(&Role::User).unwrap(), r#""user""#);
        assert_eq!(
            serde_json::to_string(&Role::Assistant).unwrap(),
            r#""assistant""#
        );
        let role: Role = serde_json::from_str(r#""assistant""#).unwrap();
        assert_eq!(role, Role::Assistant);
    }

    #[test]
    fn test_stop_reason_serializes_snake_case() {
        assert_eq!(
            serde_json::to_string(&StopReason::EndTurn).unwrap(),
            r#""end_turn""#
        );
        assert_eq!(
            serde_json::to_string(&StopReason::MaxTokens).unwrap(),
            r#""max_tokens""#
        );
        assert_eq!(
            serde_json::to_string(&StopReason::StopSequence).unwrap(),
            r#""stop_sequence""#
        );
        let reason: StopReason = serde_json::from_str(r#""max_tokens""#).unwrap();
        assert_eq!(reason, StopReason::MaxTokens);
    }
}