noether_engine/llm/
mod.rs1pub mod anthropic;
2pub mod mistral;
3pub mod openai;
4pub mod vertex;
5
6use serde::{Deserialize, Serialize};
7
8#[derive(Debug, thiserror::Error)]
9pub enum LlmError {
10 #[error("LLM provider error: {0}")]
11 Provider(String),
12 #[error("HTTP error: {0}")]
13 Http(String),
14 #[error("response parse error: {0}")]
15 Parse(String),
16}
17
18#[derive(Debug, Clone, Serialize, Deserialize)]
19pub enum Role {
20 System,
21 User,
22 Assistant,
23}
24
25#[derive(Debug, Clone, Serialize, Deserialize)]
26pub struct Message {
27 pub role: Role,
28 pub content: String,
29}
30
31impl Message {
32 pub fn system(content: impl Into<String>) -> Self {
33 Self {
34 role: Role::System,
35 content: content.into(),
36 }
37 }
38
39 pub fn user(content: impl Into<String>) -> Self {
40 Self {
41 role: Role::User,
42 content: content.into(),
43 }
44 }
45
46 pub fn assistant(content: impl Into<String>) -> Self {
47 Self {
48 role: Role::Assistant,
49 content: content.into(),
50 }
51 }
52}
53
/// Generation settings passed alongside the conversation to [`LlmProvider::complete`].
#[derive(Clone, Debug)]
pub struct LlmConfig {
    /// Model identifier; how it is interpreted is up to each provider.
    pub model: String,
    /// Token budget for the completion; presumably the provider's output cap (confirm per provider).
    pub max_tokens: u32,
    /// Sampling temperature.
    pub temperature: f32,
}
60
61impl Default for LlmConfig {
62 fn default() -> Self {
63 Self {
64 model: std::env::var("VERTEX_AI_MODEL").unwrap_or_else(|_| "mistral-small-2503".into()),
67 max_tokens: 8192,
68 temperature: 0.2,
69 }
70 }
71}
72
73pub trait LlmProvider: Send + Sync {
75 fn complete(&self, messages: &[Message], config: &LlmConfig) -> Result<String, LlmError>;
76}
77
/// Test double whose [`LlmProvider::complete`] always yields the same canned text.
///
/// Derives `Debug` (like the other public types in this module) and `Clone`,
/// so a mock can be inspected in assertions and duplicated cheaply.
#[derive(Debug, Clone)]
pub struct MockLlmProvider {
    /// Text returned by every `complete` call.
    response: String,
}
83
84impl MockLlmProvider {
85 pub fn new(response: impl Into<String>) -> Self {
86 Self {
87 response: response.into(),
88 }
89 }
90}
91
92impl LlmProvider for MockLlmProvider {
93 fn complete(&self, _messages: &[Message], _config: &LlmConfig) -> Result<String, LlmError> {
94 Ok(self.response.clone())
95 }
96}
97
/// Test double that hands out scripted responses one per call, then keeps returning
/// `fallback` once the script is used up.
///
/// Derives `Debug` so a mock can be inspected in assertions and failure output.
#[derive(Debug)]
pub struct SequenceMockLlmProvider {
    /// Scripted responses not yet handed out, front first. Wrapped in a `Mutex`
    /// because `complete` takes `&self` but must consume entries, and the trait
    /// requires `Send + Sync`.
    responses: std::sync::Mutex<std::collections::VecDeque<String>>,
    /// Response returned after `responses` has been drained.
    fallback: String,
}
105
106impl SequenceMockLlmProvider {
107 pub fn new(responses: Vec<impl Into<String>>, fallback: impl Into<String>) -> Self {
108 Self {
109 responses: std::sync::Mutex::new(responses.into_iter().map(|s| s.into()).collect()),
110 fallback: fallback.into(),
111 }
112 }
113}
114
115impl LlmProvider for SequenceMockLlmProvider {
116 fn complete(&self, _messages: &[Message], _config: &LlmConfig) -> Result<String, LlmError> {
117 let mut queue = self.responses.lock().unwrap();
118 Ok(queue.pop_front().unwrap_or_else(|| self.fallback.clone()))
119 }
120}
121
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn mock_returns_configured_response() {
        let provider = MockLlmProvider::new("hello world");
        let result = provider
            .complete(&[Message::user("test")], &LlmConfig::default())
            .unwrap();
        assert_eq!(result, "hello world");
    }

    #[test]
    fn message_constructors() {
        let sys = Message::system("sys");
        assert!(matches!(sys.role, Role::System));
        assert_eq!(sys.content, "sys");
        let usr = Message::user("usr");
        assert!(matches!(usr.role, Role::User));
        assert_eq!(usr.content, "usr");
        let ast = Message::assistant("ast");
        assert!(matches!(ast.role, Role::Assistant));
        assert_eq!(ast.content, "ast");
    }

    /// The sequence mock must replay its script in order and then repeat the fallback.
    #[test]
    fn sequence_mock_replays_in_order_then_falls_back() {
        let provider = SequenceMockLlmProvider::new(vec!["first", "second"], "done");
        let config = LlmConfig::default();
        let ask = || provider.complete(&[Message::user("q")], &config).unwrap();
        assert_eq!(ask(), "first");
        assert_eq!(ask(), "second");
        assert_eq!(ask(), "done");
        assert_eq!(ask(), "done");
    }

    /// Pins the user-visible error messages produced by the `#[error]` attributes.
    #[test]
    fn llm_error_display_messages() {
        assert_eq!(
            LlmError::Provider("quota".into()).to_string(),
            "LLM provider error: quota"
        );
        assert_eq!(LlmError::Http("timeout".into()).to_string(), "HTTP error: timeout");
        assert_eq!(
            LlmError::Parse("bad json".into()).to_string(),
            "response parse error: bad json"
        );
    }
}