// pulsehive_core/src/llm.rs
1//! LLM provider abstraction and message types.
2//!
3//! [`LlmProvider`] is the trait that provider crates (`pulsehive-openai`, `pulsehive-anthropic`)
4//! implement. Products can also implement custom providers.
5//!
6//! Messages serialize to OpenAI's chat completions format:
7//! ```json
8//! {"role": "system", "content": "You are helpful"}
9//! {"role": "user", "content": "Hello"}
10//! {"role": "assistant", "content": null, "tool_calls": [...]}
11//! {"role": "tool", "tool_call_id": "call_1", "content": "result"}
12//! ```
13
14use std::pin::Pin;
15
16use async_trait::async_trait;
17use futures_core::Stream;
18use serde::{Deserialize, Serialize};
19use serde_json::Value;
20
21use crate::error::Result;
22use crate::tool::Tool;
23
/// LLM model selection and generation parameters.
///
/// The `provider` field routes to a named [`LlmProvider`] instance registered
/// with the HiveMind builder (e.g., `"openai"`, `"anthropic"`).
///
/// Serializable so it can be stored in product configuration files.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LlmConfig {
    /// Provider name — matches the key used in `HiveMind::builder().llm_provider(name, ...)`.
    pub provider: String,
    /// Model identifier (e.g., `"claude-sonnet-4-6"`, `"gpt-4"`, `"glm-5"`).
    pub model: String,
    /// Sampling temperature (0.0 = deterministic, 1.0+ = creative).
    /// NOTE(review): exact scale/clamping is provider-specific — confirm per backend.
    pub temperature: f32,
    /// Maximum tokens to generate.
    pub max_tokens: u32,
}
39
40impl LlmConfig {
41    /// Creates a config with sensible defaults (temperature 0.7, max_tokens 4096).
42    pub fn new(provider: impl Into<String>, model: impl Into<String>) -> Self {
43        Self {
44            provider: provider.into(),
45            model: model.into(),
46            temperature: 0.7,
47            max_tokens: 4096,
48        }
49    }
50}
51
/// A message in a multi-turn conversation.
///
/// Serializes to OpenAI's chat completions format with `"role"` as the discriminator.
/// This format is compatible with OpenAI, GLM, vLLM, Ollama, and other OpenAI-compatible APIs.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "role", rename_all = "snake_case")]
pub enum Message {
    /// System prompt that configures agent behavior.
    /// Serializes as: `{"role": "system", "content": "..."}`
    System { content: String },

    /// User message (task description, follow-up, etc.).
    /// Serializes as: `{"role": "user", "content": "..."}`
    User { content: String },

    /// Assistant response, optionally with tool calls.
    /// Serializes as: `{"role": "assistant", "content": "...", "tool_calls": [...]}`
    ///
    /// `content: None` serializes as JSON `null`. An empty `tool_calls` vec is
    /// omitted on serialization (`skip_serializing_if`) and tolerated as absent
    /// on deserialization (`default`) — some APIs omit the field entirely.
    Assistant {
        content: Option<String>,
        #[serde(default, skip_serializing_if = "Vec::is_empty")]
        tool_calls: Vec<ToolCall>,
    },

    /// Result of a tool execution, sent back to the LLM.
    /// Serializes as: `{"role": "tool", "tool_call_id": "...", "content": "..."}`
    /// Renamed so the wire tag is `"tool"` (OpenAI's role name), not `"tool_result"`.
    #[serde(rename = "tool")]
    ToolResult {
        tool_call_id: String,
        content: String,
    },
}
83
84impl Message {
85    /// Creates a system message.
86    pub fn system(content: impl Into<String>) -> Self {
87        Self::System {
88            content: content.into(),
89        }
90    }
91
92    /// Creates a user message.
93    pub fn user(content: impl Into<String>) -> Self {
94        Self::User {
95            content: content.into(),
96        }
97    }
98
99    /// Creates an assistant message with text content (no tool calls).
100    pub fn assistant(content: impl Into<String>) -> Self {
101        Self::Assistant {
102            content: Some(content.into()),
103            tool_calls: vec![],
104        }
105    }
106
107    /// Creates an assistant message with tool calls (no text content).
108    pub fn assistant_with_tool_calls(tool_calls: Vec<ToolCall>) -> Self {
109        Self::Assistant {
110            content: None,
111            tool_calls,
112        }
113    }
114
115    /// Creates a tool result message.
116    pub fn tool_result(tool_call_id: impl Into<String>, content: impl Into<String>) -> Self {
117        Self::ToolResult {
118            tool_call_id: tool_call_id.into(),
119            content: content.into(),
120        }
121    }
122}
123
/// An LLM's request to invoke a tool.
///
/// NOTE(review): this flat `{id, name, arguments}` shape differs from OpenAI's
/// wire format, which nests name/arguments under a `"function"` object with
/// `"type": "function"` and string-encoded arguments — presumably the provider
/// crates translate between the two; confirm in `pulsehive-openai`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolCall {
    /// Unique ID for this tool call (used to match with ToolResult).
    pub id: String,
    /// Name of the tool to invoke.
    pub name: String,
    /// JSON arguments parsed from the LLM's output.
    pub arguments: Value,
}
134
/// Tool schema sent to the LLM so it knows what tools are available.
///
/// Built from a [`Tool`] via [`ToolDefinition::from_tool`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolDefinition {
    /// Tool name (must match `Tool::name()`).
    pub name: String,
    /// Description the LLM uses to decide when to invoke this tool.
    pub description: String,
    /// JSON Schema describing the tool's parameters.
    pub parameters: Value,
}
145
146impl ToolDefinition {
147    /// Creates a ToolDefinition from a Tool trait object.
148    ///
149    /// Extracts name, description, and parameters from the tool for sending to the LLM.
150    pub fn from_tool(tool: &dyn Tool) -> Self {
151        Self {
152            name: tool.name().to_string(),
153            description: tool.description().to_string(),
154            parameters: tool.parameters(),
155        }
156    }
157}
158
159/// Token usage statistics from an LLM call.
160#[derive(Debug, Clone, Default, Serialize, Deserialize)]
161pub struct TokenUsage {
162    /// Tokens consumed by the input (prompt + context).
163    pub input_tokens: u32,
164    /// Tokens generated in the output.
165    pub output_tokens: u32,
166}
167
/// Complete response from a non-streaming LLM call.
///
/// Returned by [`LlmProvider::chat`]. Either `content` or `tool_calls`
/// (or both) carries the model's output.
#[derive(Debug, Clone)]
pub struct LlmResponse {
    /// Text content of the response (`None` if only tool calls).
    pub content: Option<String>,
    /// Tool calls requested by the LLM (empty if text-only response).
    pub tool_calls: Vec<ToolCall>,
    /// Token usage statistics.
    pub usage: TokenUsage,
}
178
/// A chunk from a streaming LLM response.
///
/// Yielded by the stream returned from [`LlmProvider::chat_stream`].
/// Tool-call arguments arrive incrementally: `ToolCallStart` announces the
/// call, then one or more `ToolCallDelta`s carry argument fragments.
#[derive(Debug, Clone)]
pub enum LlmChunk {
    /// A text token delta.
    Text(String),
    /// Start of a tool call (id and name known).
    ToolCallStart { id: String, name: String },
    /// Incremental arguments for an in-progress tool call.
    ToolCallDelta { id: String, arguments_delta: String },
    /// Stream is complete.
    Done,
}
191
/// Trait for LLM provider implementations.
///
/// Provider crates (`pulsehive-openai`, `pulsehive-anthropic`) implement this trait.
/// Products can also implement custom providers for self-hosted models.
///
/// Must be `Send + Sync` for use across Tokio tasks and object-safe for
/// `Arc<dyn LlmProvider>` in HiveMind.
#[async_trait]
pub trait LlmProvider: Send + Sync {
    /// Send a chat completion request and return the full response.
    ///
    /// # Errors
    /// Provider-specific failures (network, auth, rate limiting, malformed
    /// responses) — exact variants depend on the implementing crate.
    async fn chat(
        &self,
        messages: Vec<Message>,
        tools: Vec<ToolDefinition>,
        config: &LlmConfig,
    ) -> Result<LlmResponse>;

    /// Send a chat completion request and stream tokens.
    ///
    /// # Errors
    /// Fails if the stream cannot be established; individual chunks may also
    /// carry errors (each item is a `Result<LlmChunk>`).
    async fn chat_stream(
        &self,
        messages: Vec<Message>,
        tools: Vec<ToolDefinition>,
        config: &LlmConfig,
    ) -> Result<Pin<Box<dyn Stream<Item = Result<LlmChunk>> + Send>>>;
}
217
#[cfg(test)]
mod tests {
    use super::*;

    // Compile-time guarantee: `LlmProvider` stays object-safe so it can live
    // behind `Arc<dyn LlmProvider>` in HiveMind. The body never runs any
    // provider code — the `&dyn` parameter is the whole assertion.
    #[test]
    fn test_llm_provider_is_object_safe() {
        fn _assert_object_safe(_: &dyn LlmProvider) {}
    }

    #[test]
    fn test_llm_config_new_defaults() {
        let config = LlmConfig::new("openai", "gpt-4");
        assert_eq!(config.provider, "openai");
        assert_eq!(config.model, "gpt-4");
        assert!((config.temperature - 0.7).abs() < f32::EPSILON);
        assert_eq!(config.max_tokens, 4096);
    }

    #[test]
    fn test_llm_config_serialization() {
        let config = LlmConfig::new("anthropic", "claude-sonnet-4-6");
        let json = serde_json::to_string(&config).unwrap();
        let deserialized: LlmConfig = serde_json::from_str(&json).unwrap();
        assert_eq!(deserialized.provider, "anthropic");
        assert_eq!(deserialized.model, "claude-sonnet-4-6");
    }

    // ── OpenAI format serialization tests ────────────────────────────
    // These pin the exact wire shape documented on `Message`; changing the
    // serde attributes on the enum should break one of these.

    #[test]
    fn test_message_system_openai_format() {
        let msg = Message::system("You are helpful");
        let json = serde_json::to_value(&msg).unwrap();
        assert_eq!(json["role"], "system");
        assert_eq!(json["content"], "You are helpful");
        assert_eq!(json.as_object().unwrap().len(), 2); // only role + content
    }

    #[test]
    fn test_message_user_openai_format() {
        let msg = Message::user("Hello");
        let json = serde_json::to_value(&msg).unwrap();
        assert_eq!(json["role"], "user");
        assert_eq!(json["content"], "Hello");
    }

    #[test]
    fn test_message_assistant_text_only_format() {
        let msg = Message::assistant("The answer is 42");
        let json = serde_json::to_value(&msg).unwrap();
        assert_eq!(json["role"], "assistant");
        assert_eq!(json["content"], "The answer is 42");
        // tool_calls should be absent (skip_serializing_if empty)
        assert!(json.get("tool_calls").is_none());
    }

    #[test]
    fn test_message_assistant_with_tool_calls_format() {
        let msg = Message::assistant_with_tool_calls(vec![ToolCall {
            id: "call_abc".into(),
            name: "search".into(),
            arguments: serde_json::json!({"query": "rust"}),
        }]);
        let json = serde_json::to_value(&msg).unwrap();
        assert_eq!(json["role"], "assistant");
        assert!(json["content"].is_null());
        assert_eq!(json["tool_calls"][0]["id"], "call_abc");
        assert_eq!(json["tool_calls"][0]["name"], "search");
    }

    #[test]
    fn test_message_tool_result_openai_format() {
        let msg = Message::tool_result("call_abc", "Search results here");
        let json = serde_json::to_value(&msg).unwrap();
        assert_eq!(json["role"], "tool"); // NOT "tool_result"
        assert_eq!(json["tool_call_id"], "call_abc");
        assert_eq!(json["content"], "Search results here");
    }

    #[test]
    fn test_message_serde_roundtrip_all_variants() {
        let messages = [
            Message::system("Be helpful"),
            Message::user("Hi"),
            Message::assistant("Hello!"),
            Message::assistant_with_tool_calls(vec![ToolCall {
                id: "c1".into(),
                name: "read".into(),
                arguments: serde_json::json!({}),
            }]),
            Message::tool_result("c1", "file contents"),
        ];

        for msg in &messages {
            let json = serde_json::to_string(msg).unwrap();
            let deserialized: Message = serde_json::from_str(&json).unwrap();
            // Verify roundtrip produces same JSON
            let json2 = serde_json::to_string(&deserialized).unwrap();
            assert_eq!(json, json2);
        }
    }

    #[test]
    fn test_message_deserialize_from_openai_response() {
        // Simulate what we'd get back from OpenAI's API
        let openai_json = r#"{"role": "assistant", "content": "Hello!", "tool_calls": []}"#;
        let msg: Message = serde_json::from_str(openai_json).unwrap();
        assert!(matches!(msg, Message::Assistant { content: Some(c), .. } if c == "Hello!"));
    }

    #[test]
    fn test_message_deserialize_assistant_without_tool_calls() {
        // OpenAI sometimes omits tool_calls entirely
        let openai_json = r#"{"role": "assistant", "content": "Hello!"}"#;
        let msg: Message = serde_json::from_str(openai_json).unwrap();
        match msg {
            Message::Assistant {
                content,
                tool_calls,
            } => {
                assert_eq!(content, Some("Hello!".into()));
                assert!(tool_calls.is_empty()); // default
            }
            _ => panic!("Expected Assistant"),
        }
    }

    // ── Convenience constructor tests ────────────────────────────────

    #[test]
    fn test_message_convenience_constructors() {
        assert!(matches!(Message::system("x"), Message::System { content } if content == "x"));
        assert!(matches!(Message::user("y"), Message::User { content } if content == "y"));
        assert!(
            matches!(Message::assistant("z"), Message::Assistant { content: Some(c), tool_calls } if c == "z" && tool_calls.is_empty())
        );
        assert!(
            matches!(Message::assistant_with_tool_calls(vec![]), Message::Assistant { content: None, tool_calls } if tool_calls.is_empty())
        );
        assert!(
            matches!(Message::tool_result("id", "res"), Message::ToolResult { tool_call_id, content } if tool_call_id == "id" && content == "res")
        );
    }

    // ── ToolDefinition tests ─────────────────────────────────────────

    #[test]
    fn test_tool_definition_from_tool() {
        use crate::error::PulseHiveError;
        use crate::tool::{ToolContext, ToolResult};

        struct MockTool;

        #[async_trait]
        impl Tool for MockTool {
            fn name(&self) -> &str {
                "mock_tool"
            }
            fn description(&self) -> &str {
                "A mock tool for testing"
            }
            fn parameters(&self) -> Value {
                serde_json::json!({"type": "object", "properties": {"x": {"type": "string"}}})
            }
            async fn execute(
                &self,
                _params: Value,
                _ctx: &ToolContext,
            ) -> std::result::Result<ToolResult, PulseHiveError> {
                Ok(ToolResult::text("ok"))
            }
        }

        let def = ToolDefinition::from_tool(&MockTool);
        assert_eq!(def.name, "mock_tool");
        assert_eq!(def.description, "A mock tool for testing");
        assert_eq!(def.parameters["type"], "object");
    }

    #[test]
    fn test_multi_turn_conversation_serialization() {
        let conversation = [
            Message::system("You are a code assistant."),
            Message::user("Read the config file."),
            Message::assistant_with_tool_calls(vec![ToolCall {
                id: "call_1".into(),
                name: "read_file".into(),
                arguments: serde_json::json!({"path": "config.toml"}),
            }]),
            Message::tool_result("call_1", "[package]\nname = \"test\""),
            Message::assistant("The config file defines a package named 'test'."),
        ];

        // Verify all serialize to valid JSON with role field
        for msg in &conversation {
            let json = serde_json::to_value(msg).unwrap();
            assert!(json.get("role").is_some(), "Missing role field");
        }
        assert_eq!(conversation.len(), 5);
    }

    // ── Other type tests ─────────────────────────────────────────────

    #[test]
    fn test_tool_definition_construction() {
        let tool = ToolDefinition {
            name: "search".into(),
            description: "Search the codebase".into(),
            parameters: serde_json::json!({
                "type": "object",
                "properties": {
                    "query": {"type": "string"}
                },
                "required": ["query"]
            }),
        };
        assert_eq!(tool.name, "search");
    }

    #[test]
    fn test_token_usage_default() {
        let usage = TokenUsage::default();
        assert_eq!(usage.input_tokens, 0);
        assert_eq!(usage.output_tokens, 0);
    }

    #[test]
    fn test_llm_chunk_variants() {
        let text = LlmChunk::Text("hello".into());
        assert!(matches!(text, LlmChunk::Text(s) if s == "hello"));

        let start = LlmChunk::ToolCallStart {
            id: "1".into(),
            name: "search".into(),
        };
        assert!(matches!(start, LlmChunk::ToolCallStart { .. }));

        let delta = LlmChunk::ToolCallDelta {
            id: "1".into(),
            arguments_delta: "{\"q".into(),
        };
        assert!(matches!(delta, LlmChunk::ToolCallDelta { .. }));

        let done = LlmChunk::Done;
        assert!(matches!(done, LlmChunk::Done));
    }
}