ecl_core/llm/provider.rs

//! LLM provider abstraction.

use async_trait::async_trait;
use serde::{Deserialize, Serialize};

use crate::Result;

/// Abstraction over LLM providers (Claude, GPT, etc.).
///
/// This trait allows swapping LLM backends without changing workflow code.
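///
/// # Example
///
/// A minimal sketch of a custom implementation. `MockProvider` and its echo
/// behavior are illustrative assumptions, not part of this crate; a real
/// backend would call the provider's API inside `complete`.
///
/// ```ignore
/// use async_trait::async_trait;
///
/// struct MockProvider;
///
/// #[async_trait]
/// impl LlmProvider for MockProvider {
///     async fn complete(&self, request: CompletionRequest) -> Result<CompletionResponse> {
///         // Echo the last message back instead of calling a real model.
///         let echoed = request
///             .messages
///             .last()
///             .map(|m| m.content.clone())
///             .unwrap_or_default();
///         Ok(CompletionResponse {
///             content: echoed,
///             tokens_used: TokenUsage { input: 0, output: 0 },
///             stop_reason: StopReason::EndTurn,
///         })
///     }
///
///     async fn complete_streaming(&self, _request: CompletionRequest) -> Result<CompletionStream> {
///         unimplemented!("streaming arrives in Phase 3")
///     }
/// }
/// ```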
#[async_trait]
pub trait LlmProvider: Send + Sync {
    /// Completes a prompt and returns the full response.
    ///
    /// This awaits the entire response before returning; nothing is streamed.
    async fn complete(&self, request: CompletionRequest) -> Result<CompletionResponse>;

    /// Completes a prompt with a streaming response.
    ///
    /// Returns a stream of response chunks as they arrive.
    async fn complete_streaming(&self, request: CompletionRequest) -> Result<CompletionStream>;
}

/// A request to complete a prompt.
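///
/// Requests are usually assembled with [`CompletionRequest::new`] plus the
/// `with_*` builder methods. A small usage sketch (the prompt strings here are
/// made up for illustration):
///
/// ```ignore
/// let request = CompletionRequest::new(vec![Message::user("Summarize this file")])
///     .with_system_prompt("You are a concise assistant")
///     .with_max_tokens(512)
///     .with_temperature(0.2)
///     .with_stop_sequence("\n\n");
/// ```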
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CompletionRequest {
    /// System prompt (context/instructions)
    pub system_prompt: Option<String>,

    /// Conversation messages
    pub messages: Vec<Message>,

    /// Maximum tokens to generate
    pub max_tokens: u32,

    /// Temperature (0.0 = deterministic, 1.0 = creative)
    pub temperature: Option<f32>,

    /// Stop sequences
    pub stop_sequences: Vec<String>,
}

impl CompletionRequest {
    /// Creates a new completion request with default settings.
    pub fn new(messages: Vec<Message>) -> Self {
        Self {
            system_prompt: None,
            messages,
            max_tokens: 1024,
            temperature: None,
            stop_sequences: Vec::new(),
        }
    }

    /// Sets the system prompt.
    pub fn with_system_prompt(mut self, prompt: impl Into<String>) -> Self {
        self.system_prompt = Some(prompt.into());
        self
    }

    /// Sets the maximum tokens.
    pub fn with_max_tokens(mut self, max_tokens: u32) -> Self {
        self.max_tokens = max_tokens;
        self
    }

    /// Sets the temperature.
    pub fn with_temperature(mut self, temperature: f32) -> Self {
        self.temperature = Some(temperature);
        self
    }

    /// Adds a stop sequence.
    pub fn with_stop_sequence(mut self, sequence: impl Into<String>) -> Self {
        self.stop_sequences.push(sequence.into());
        self
    }
}

/// A message in the conversation.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct Message {
    /// Role of the message sender
    pub role: Role,

    /// Message content
    pub content: String,
}

impl Message {
    /// Creates a user message.
    pub fn user(content: impl Into<String>) -> Self {
        Self {
            role: Role::User,
            content: content.into(),
        }
    }

    /// Creates an assistant message.
    pub fn assistant(content: impl Into<String>) -> Self {
        Self {
            role: Role::Assistant,
            content: content.into(),
        }
    }
}

/// Role of a message sender.
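///
/// With `#[serde(rename_all = "lowercase")]`, `Role::User` serializes as
/// `"user"` and `Role::Assistant` as `"assistant"`.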
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    /// User message
    User,
    /// Assistant message
    Assistant,
}

/// Response from an LLM completion.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CompletionResponse {
    /// Generated content
    pub content: String,

    /// Token usage statistics
    pub tokens_used: TokenUsage,

    /// Why the model stopped generating
    pub stop_reason: StopReason,
}

/// Token usage statistics.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub struct TokenUsage {
    /// Input tokens consumed
    pub input: u64,

    /// Output tokens generated
    pub output: u64,
}

impl TokenUsage {
    /// Total tokens used (input + output).
    pub fn total(&self) -> u64 {
        self.input + self.output
    }
}

/// Reason why the model stopped generating.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
#[non_exhaustive]
pub enum StopReason {
    /// Reached the end of the response naturally
    EndTurn,

    /// Hit the maximum token limit
    MaxTokens,

    /// Encountered a stop sequence
    StopSequence,
}

/// Streaming response from an LLM completion.
///
/// This is a placeholder for now; full implementation in Phase 3.
pub struct CompletionStream {
    // Future: implement streaming using tokio::sync::mpsc or similar
    _private: (),
}
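
// One possible shape for the Phase 3 streaming implementation. This is an
// illustrative sketch only (it assumes chunks of generated text are delivered
// over a tokio mpsc channel); the final design may differ:
//
//     pub struct CompletionStream {
//         rx: tokio::sync::mpsc::Receiver<Result<String>>,
//     }
//
//     impl CompletionStream {
//         /// Returns the next chunk of generated text, or `None` once the
//         /// provider has finished streaming.
//         pub async fn next_chunk(&mut self) -> Option<Result<String>> {
//             self.rx.recv().await
//         }
//     }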

#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    #[test]
    fn test_message_constructors() {
        let user_msg = Message::user("Hello");
        assert_eq!(user_msg.role, Role::User);
        assert_eq!(user_msg.content, "Hello");

        let asst_msg = Message::assistant("Hi there");
        assert_eq!(asst_msg.role, Role::Assistant);
        assert_eq!(asst_msg.content, "Hi there");
    }

    #[test]
    fn test_completion_request_builder() {
        let request = CompletionRequest::new(vec![Message::user("Test")])
            .with_system_prompt("You are helpful")
            .with_max_tokens(2048)
            .with_temperature(0.7)
            .with_stop_sequence("\n\n");

        assert_eq!(request.system_prompt, Some("You are helpful".to_string()));
        assert_eq!(request.max_tokens, 2048);
        assert_eq!(request.temperature, Some(0.7));
        assert_eq!(request.stop_sequences, vec!["\n\n"]);
    }

    #[test]
    fn test_token_usage_total() {
        let usage = TokenUsage {
            input: 100,
            output: 200,
        };
        assert_eq!(usage.total(), 300);
    }

    #[test]
    fn test_message_serialization() {
        let msg = Message::user("test content");
        let json = serde_json::to_string(&msg).unwrap();
        let deserialized: Message = serde_json::from_str(&json).unwrap();
        assert_eq!(msg, deserialized);
    }
}