1use std::pin::Pin;
15
16use async_trait::async_trait;
17use futures_core::Stream;
18use serde::{Deserialize, Serialize};
19use serde_json::Value;
20
21use crate::error::Result;
22use crate::tool::Tool;
23
/// Configuration for a single LLM request: which backend/model to call and
/// the sampling parameters to use.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LlmConfig {
    /// Provider identifier, e.g. "openai" or "anthropic".
    pub provider: String,
    /// Model name as understood by the provider, e.g. "gpt-4".
    pub model: String,
    /// Sampling temperature (defaults to 0.7 via [`LlmConfig::new`]).
    pub temperature: f32,
    /// Maximum number of tokens to generate (defaults to 4096 via [`LlmConfig::new`]).
    pub max_tokens: u32,
}
39
40impl LlmConfig {
41 pub fn new(provider: impl Into<String>, model: impl Into<String>) -> Self {
43 Self {
44 provider: provider.into(),
45 model: model.into(),
46 temperature: 0.7,
47 max_tokens: 4096,
48 }
49 }
50}
51
/// One turn in a chat conversation.
///
/// Serialized in an OpenAI-compatible tagged form: the variant becomes a
/// `"role"` field with values `system`, `user`, `assistant`, or `tool`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "role", rename_all = "snake_case")]
pub enum Message {
    /// System prompt that sets up the assistant's behavior.
    System { content: String },

    /// End-user input.
    User { content: String },

    /// Assistant output: text, tool invocations, or both.
    Assistant {
        /// Text content; `None` when the turn consists solely of tool calls.
        content: Option<String>,
        /// Tool invocations requested by the model; omitted from JSON when empty,
        /// and defaults to empty when absent during deserialization.
        #[serde(default, skip_serializing_if = "Vec::is_empty")]
        tool_calls: Vec<ToolCall>,
    },

    /// Result of executing a tool call, fed back to the model.
    /// Renamed so the wire role is `"tool"` rather than `"tool_result"`.
    #[serde(rename = "tool")]
    ToolResult {
        /// Must match the `id` of the [`ToolCall`] this result answers.
        tool_call_id: String,
        /// Tool output rendered as text.
        content: String,
    },
}
83
84impl Message {
85 pub fn system(content: impl Into<String>) -> Self {
87 Self::System {
88 content: content.into(),
89 }
90 }
91
92 pub fn user(content: impl Into<String>) -> Self {
94 Self::User {
95 content: content.into(),
96 }
97 }
98
99 pub fn assistant(content: impl Into<String>) -> Self {
101 Self::Assistant {
102 content: Some(content.into()),
103 tool_calls: vec![],
104 }
105 }
106
107 pub fn assistant_with_tool_calls(tool_calls: Vec<ToolCall>) -> Self {
109 Self::Assistant {
110 content: None,
111 tool_calls,
112 }
113 }
114
115 pub fn tool_result(tool_call_id: impl Into<String>, content: impl Into<String>) -> Self {
117 Self::ToolResult {
118 tool_call_id: tool_call_id.into(),
119 content: content.into(),
120 }
121 }
122}
123
/// A tool invocation requested by the model.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolCall {
    /// Provider-assigned call identifier, echoed back via
    /// [`Message::ToolResult`]'s `tool_call_id`.
    pub id: String,
    /// Name of the tool to invoke.
    pub name: String,
    /// Arguments for the tool as a JSON value.
    pub arguments: Value,
}
134
/// Description of a tool advertised to the model so it can decide to call it.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolDefinition {
    /// Unique tool name the model uses to reference it.
    pub name: String,
    /// Human/model-readable explanation of what the tool does.
    pub description: String,
    /// JSON Schema (as a JSON value) describing the tool's parameters.
    pub parameters: Value,
}
145
146impl ToolDefinition {
147 pub fn from_tool(tool: &dyn Tool) -> Self {
151 Self {
152 name: tool.name().to_string(),
153 description: tool.description().to_string(),
154 parameters: tool.parameters(),
155 }
156 }
157}
158
/// Token accounting for one request/response exchange.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct TokenUsage {
    /// Tokens consumed by the prompt (input) side.
    pub input_tokens: u32,
    /// Tokens generated by the model (output) side.
    pub output_tokens: u32,
}
167
/// A complete (non-streaming) response from an LLM provider.
#[derive(Debug, Clone)]
pub struct LlmResponse {
    /// Assistant text, if the model produced any.
    pub content: Option<String>,
    /// Tool calls the model requested; empty when none.
    pub tool_calls: Vec<ToolCall>,
    /// Token usage reported for this exchange.
    pub usage: TokenUsage,
}
178
/// Incremental event emitted while streaming a response
/// (see [`LlmProvider::chat_stream`]).
#[derive(Debug, Clone)]
pub enum LlmChunk {
    /// A fragment of assistant text.
    Text(String),
    /// A tool call has started; its arguments arrive via subsequent
    /// [`LlmChunk::ToolCallDelta`] events with the same `id`.
    ToolCallStart { id: String, name: String },
    /// A fragment of the JSON-encoded arguments for the tool call `id`.
    ToolCallDelta { id: String, arguments_delta: String },
    /// The stream has finished; no further chunks follow.
    Done,
}
191
/// Abstraction over a chat-completion backend.
///
/// The trait is object-safe, so implementations can be stored and called
/// behind `dyn LlmProvider`. Implementors must be `Send + Sync` so a provider
/// can be shared across tasks/threads.
#[async_trait]
pub trait LlmProvider: Send + Sync {
    /// Sends the conversation plus the advertised tool definitions and waits
    /// for the complete response.
    async fn chat(
        &self,
        messages: Vec<Message>,
        tools: Vec<ToolDefinition>,
        config: &LlmConfig,
    ) -> Result<LlmResponse>;

    /// Like [`LlmProvider::chat`], but yields the response incrementally as a
    /// stream of [`LlmChunk`] events.
    async fn chat_stream(
        &self,
        messages: Vec<Message>,
        tools: Vec<ToolDefinition>,
        config: &LlmConfig,
    ) -> Result<Pin<Box<dyn Stream<Item = Result<LlmChunk>> + Send>>>;
}
217
#[cfg(test)]
mod tests {
    use super::*;

    /// `LlmProvider` must remain object-safe so callers can hold
    /// `dyn LlmProvider`; this fails to compile if object safety breaks.
    #[test]
    fn test_llm_provider_is_object_safe() {
        fn _assert_object_safe(_: &dyn LlmProvider) {}
    }

    #[test]
    fn test_llm_config_new_defaults() {
        let config = LlmConfig::new("openai", "gpt-4");
        assert_eq!(config.provider, "openai");
        assert_eq!(config.model, "gpt-4");
        assert!((config.temperature - 0.7).abs() < f32::EPSILON);
        assert_eq!(config.max_tokens, 4096);
    }

    #[test]
    fn test_llm_config_serialization() {
        let config = LlmConfig::new("anthropic", "claude-sonnet-4-6");
        let json = serde_json::to_string(&config).unwrap();
        let deserialized: LlmConfig = serde_json::from_str(&json).unwrap();
        assert_eq!(deserialized.provider, "anthropic");
        assert_eq!(deserialized.model, "claude-sonnet-4-6");
    }

    #[test]
    fn test_message_system_openai_format() {
        let msg = Message::system("You are helpful");
        let json = serde_json::to_value(&msg).unwrap();
        assert_eq!(json["role"], "system");
        assert_eq!(json["content"], "You are helpful");
        // Exactly `role` + `content`: no stray fields leak into the wire format.
        assert_eq!(json.as_object().unwrap().len(), 2);
    }

    #[test]
    fn test_message_user_openai_format() {
        let msg = Message::user("Hello");
        let json = serde_json::to_value(&msg).unwrap();
        assert_eq!(json["role"], "user");
        assert_eq!(json["content"], "Hello");
    }

    #[test]
    fn test_message_assistant_text_only_format() {
        let msg = Message::assistant("The answer is 42");
        let json = serde_json::to_value(&msg).unwrap();
        assert_eq!(json["role"], "assistant");
        assert_eq!(json["content"], "The answer is 42");
        // Empty tool_calls must be skipped entirely, not serialized as [].
        assert!(json.get("tool_calls").is_none());
    }

    #[test]
    fn test_message_assistant_with_tool_calls_format() {
        let msg = Message::assistant_with_tool_calls(vec![ToolCall {
            id: "call_abc".into(),
            name: "search".into(),
            arguments: serde_json::json!({"query": "rust"}),
        }]);
        let json = serde_json::to_value(&msg).unwrap();
        assert_eq!(json["role"], "assistant");
        assert!(json["content"].is_null());
        assert_eq!(json["tool_calls"][0]["id"], "call_abc");
        assert_eq!(json["tool_calls"][0]["name"], "search");
    }

    #[test]
    fn test_message_tool_result_openai_format() {
        let msg = Message::tool_result("call_abc", "Search results here");
        let json = serde_json::to_value(&msg).unwrap();
        assert_eq!(json["role"], "tool");
        assert_eq!(json["tool_call_id"], "call_abc");
        assert_eq!(json["content"], "Search results here");
    }

    #[test]
    fn test_message_serde_roundtrip_all_variants() {
        let messages = [
            Message::system("Be helpful"),
            Message::user("Hi"),
            Message::assistant("Hello!"),
            Message::assistant_with_tool_calls(vec![ToolCall {
                id: "c1".into(),
                name: "read".into(),
                arguments: serde_json::json!({}),
            }]),
            Message::tool_result("c1", "file contents"),
        ];

        for msg in &messages {
            let json = serde_json::to_string(msg).unwrap();
            let deserialized: Message = serde_json::from_str(&json).unwrap();
            let json2 = serde_json::to_string(&deserialized).unwrap();
            assert_eq!(json, json2);
        }
    }

    #[test]
    fn test_message_deserialize_from_openai_response() {
        let openai_json = r#"{"role": "assistant", "content": "Hello!", "tool_calls": []}"#;
        let msg: Message = serde_json::from_str(openai_json).unwrap();
        assert!(matches!(msg, Message::Assistant { content: Some(c), .. } if c == "Hello!"));
    }

    #[test]
    fn test_message_deserialize_assistant_without_tool_calls() {
        // `tool_calls` is absent; `#[serde(default)]` must supply an empty Vec.
        let openai_json = r#"{"role": "assistant", "content": "Hello!"}"#;
        let msg: Message = serde_json::from_str(openai_json).unwrap();
        match msg {
            Message::Assistant {
                content,
                tool_calls,
            } => {
                assert_eq!(content, Some("Hello!".into()));
                assert!(tool_calls.is_empty());
            }
            _ => panic!("Expected Assistant"),
        }
    }

    #[test]
    fn test_message_convenience_constructors() {
        assert!(matches!(Message::system("x"), Message::System { content } if content == "x"));
        assert!(matches!(Message::user("y"), Message::User { content } if content == "y"));
        assert!(
            matches!(Message::assistant("z"), Message::Assistant { content: Some(c), tool_calls } if c == "z" && tool_calls.is_empty())
        );
        assert!(
            matches!(Message::assistant_with_tool_calls(vec![]), Message::Assistant { content: None, tool_calls } if tool_calls.is_empty())
        );
        assert!(
            matches!(Message::tool_result("id", "res"), Message::ToolResult { tool_call_id, content } if tool_call_id == "id" && content == "res")
        );
    }

    #[test]
    fn test_tool_definition_from_tool() {
        use crate::error::PulseHiveError;
        use crate::tool::{ToolContext, ToolResult};

        struct MockTool;

        #[async_trait]
        impl Tool for MockTool {
            fn name(&self) -> &str {
                "mock_tool"
            }
            fn description(&self) -> &str {
                "A mock tool for testing"
            }
            fn parameters(&self) -> Value {
                serde_json::json!({"type": "object", "properties": {"x": {"type": "string"}}})
            }
            async fn execute(
                &self,
                _params: Value,
                _ctx: &ToolContext,
            ) -> std::result::Result<ToolResult, PulseHiveError> {
                Ok(ToolResult::text("ok"))
            }
        }

        let def = ToolDefinition::from_tool(&MockTool);
        assert_eq!(def.name, "mock_tool");
        assert_eq!(def.description, "A mock tool for testing");
        assert_eq!(def.parameters["type"], "object");
    }

    #[test]
    fn test_multi_turn_conversation_serialization() {
        let conversation = [
            Message::system("You are a code assistant."),
            Message::user("Read the config file."),
            Message::assistant_with_tool_calls(vec![ToolCall {
                id: "call_1".into(),
                name: "read_file".into(),
                arguments: serde_json::json!({"path": "config.toml"}),
            }]),
            Message::tool_result("call_1", "[package]\nname = \"test\""),
            Message::assistant("The config file defines a package named 'test'."),
        ];

        for msg in &conversation {
            let json = serde_json::to_value(msg).unwrap();
            assert!(json.get("role").is_some(), "Missing role field");
        }
        assert_eq!(conversation.len(), 5);
    }

    #[test]
    fn test_tool_definition_construction() {
        let tool = ToolDefinition {
            name: "search".into(),
            description: "Search the codebase".into(),
            parameters: serde_json::json!({
                "type": "object",
                "properties": {
                    "query": {"type": "string"}
                },
                "required": ["query"]
            }),
        };
        assert_eq!(tool.name, "search");
    }

    #[test]
    fn test_token_usage_default() {
        let usage = TokenUsage::default();
        assert_eq!(usage.input_tokens, 0);
        assert_eq!(usage.output_tokens, 0);
    }

    #[test]
    fn test_llm_chunk_variants() {
        let text = LlmChunk::Text("hello".into());
        assert!(matches!(text, LlmChunk::Text(s) if s == "hello"));

        let start = LlmChunk::ToolCallStart {
            id: "1".into(),
            name: "search".into(),
        };
        assert!(matches!(start, LlmChunk::ToolCallStart { .. }));

        let delta = LlmChunk::ToolCallDelta {
            id: "1".into(),
            arguments_delta: "{\"q".into(),
        };
        assert!(matches!(delta, LlmChunk::ToolCallDelta { .. }));

        let done = LlmChunk::Done;
        assert!(matches!(done, LlmChunk::Done));
    }
}