//! noether_engine/llm/mod.rs — LLM provider abstraction and test mocks.
1pub mod anthropic;
2pub mod mistral;
3pub mod openai;
4pub mod vertex;
5
6use serde::{Deserialize, Serialize};
7
/// Errors that can occur while obtaining a completion from an LLM provider.
#[derive(Debug, thiserror::Error)]
pub enum LlmError {
    /// The provider reported an error; the message string carries the detail.
    #[error("LLM provider error: {0}")]
    Provider(String),
    /// HTTP-level failure, carried as a message string.
    #[error("HTTP error: {0}")]
    Http(String),
    /// The provider's response could not be parsed into the expected shape.
    #[error("response parse error: {0}")]
    Parse(String),
}
17
/// Chat role of a message author.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum Role {
    /// System / instruction message.
    System,
    /// End-user message.
    User,
    /// Model-generated message.
    Assistant,
}
24
/// A single chat message: a role plus its text content.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Message {
    /// Who authored the message.
    pub role: Role,
    /// The message text.
    pub content: String,
}
30
31impl Message {
32    pub fn system(content: impl Into<String>) -> Self {
33        Self {
34            role: Role::System,
35            content: content.into(),
36        }
37    }
38
39    pub fn user(content: impl Into<String>) -> Self {
40        Self {
41            role: Role::User,
42            content: content.into(),
43        }
44    }
45
46    pub fn assistant(content: impl Into<String>) -> Self {
47        Self {
48            role: Role::Assistant,
49            content: content.into(),
50        }
51    }
52}
53
/// Generation parameters passed to an LLM provider.
#[derive(Debug, Clone)]
pub struct LlmConfig {
    /// Model identifier sent to the provider.
    pub model: String,
    /// Token budget for the generated output (provider-interpreted).
    pub max_tokens: u32,
    /// Sampling temperature.
    pub temperature: f32,
}
60
61impl Default for LlmConfig {
62    fn default() -> Self {
63        Self {
64            // mistral-small-2503: fastest + cheapest on europe-west4 ($0.05/1K calls).
65            // Override with VERTEX_AI_MODEL=gemini-2.5-flash or =mistral-medium-3, etc.
66            model: std::env::var("VERTEX_AI_MODEL").unwrap_or_else(|_| "mistral-small-2503".into()),
67            max_tokens: 8192,
68            temperature: 0.2,
69        }
70    }
71}
72
/// Trait for LLM text completion.
///
/// Implementors are `Send + Sync` so a provider can be shared across threads.
pub trait LlmProvider: Send + Sync {
    /// Runs a completion over `messages` with the given `config`, returning
    /// the generated text or an [`LlmError`].
    fn complete(&self, messages: &[Message], config: &LlmConfig) -> Result<String, LlmError>;
}
77
/// Mock LLM provider for testing.
/// Returns the pre-configured response regardless of input.
pub struct MockLlmProvider {
    // Canned text returned by every `complete` call.
    response: String,
}
83
84impl MockLlmProvider {
85    pub fn new(response: impl Into<String>) -> Self {
86        Self {
87            response: response.into(),
88        }
89    }
90}
91
92impl LlmProvider for MockLlmProvider {
93    fn complete(&self, _messages: &[Message], _config: &LlmConfig) -> Result<String, LlmError> {
94        Ok(self.response.clone())
95    }
96}
97
/// Mock LLM provider that returns responses from a queue.
/// When the queue is exhausted, returns the fallback response.
/// Useful for testing multi-step flows like synthesis (compose → codegen → recompose).
pub struct SequenceMockLlmProvider {
    // FIFO of canned responses; wrapped in a Mutex because `complete`
    // takes `&self` but must pop from the queue.
    responses: std::sync::Mutex<std::collections::VecDeque<String>>,
    // Returned once `responses` is empty.
    fallback: String,
}
105
106impl SequenceMockLlmProvider {
107    pub fn new(responses: Vec<impl Into<String>>, fallback: impl Into<String>) -> Self {
108        Self {
109            responses: std::sync::Mutex::new(responses.into_iter().map(|s| s.into()).collect()),
110            fallback: fallback.into(),
111        }
112    }
113}
114
115impl LlmProvider for SequenceMockLlmProvider {
116    fn complete(&self, _messages: &[Message], _config: &LlmConfig) -> Result<String, LlmError> {
117        let mut queue = self.responses.lock().unwrap();
118        Ok(queue.pop_front().unwrap_or_else(|| self.fallback.clone()))
119    }
120}
121
#[cfg(test)]
mod tests {
    use super::*;

    // The mock must echo its canned response no matter what it is asked.
    #[test]
    fn mock_returns_configured_response() {
        let provider = MockLlmProvider::new("hello world");
        let messages = [Message::user("test")];
        let config = LlmConfig::default();
        let reply = provider.complete(&messages, &config).unwrap();
        assert_eq!(reply, "hello world");
    }

    // Each constructor must tag its message with the matching role.
    #[test]
    fn message_constructors() {
        assert!(matches!(Message::system("sys").role, Role::System));
        assert!(matches!(Message::user("usr").role, Role::User));
        assert!(matches!(Message::assistant("ast").role, Role::Assistant));
    }
}
144}