Skip to main content

neuron_turn/
types.rs

//! Internal types for the neuron-turn ReAct loop.
//!
//! These types are the internal lingua franca: they are neither layer0 types
//! nor provider-specific types. Providers convert to and from them.
5
6use rust_decimal::Decimal;
7use serde::{Deserialize, Serialize};
8
9/// Role in a conversation.
10#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
11#[serde(rename_all = "lowercase")]
12pub enum Role {
13    /// System message (instructions).
14    System,
15    /// User message.
16    User,
17    /// Assistant (model) message.
18    Assistant,
19}
20
21/// Source for image content.
22#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
23#[serde(tag = "type", rename_all = "snake_case")]
24pub enum ImageSource {
25    /// Base64-encoded image data.
26    Base64 {
27        /// The base64-encoded data.
28        data: String,
29    },
30    /// URL pointing to an image.
31    Url {
32        /// The image URL.
33        url: String,
34    },
35}
36
37/// A single content part within a message.
38#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
39#[serde(tag = "type", rename_all = "snake_case")]
40pub enum ContentPart {
41    /// Plain text.
42    Text {
43        /// The text content.
44        text: String,
45    },
46    /// A tool use request from the model.
47    ToolUse {
48        /// Unique identifier for this tool use.
49        id: String,
50        /// Name of the tool to invoke.
51        name: String,
52        /// Tool input parameters.
53        input: serde_json::Value,
54    },
55    /// Result from a tool execution.
56    ToolResult {
57        /// The tool_use id this result corresponds to.
58        tool_use_id: String,
59        /// The result content.
60        content: String,
61        /// Whether the tool execution errored.
62        is_error: bool,
63    },
64    /// Image content.
65    Image {
66        /// The image source.
67        source: ImageSource,
68        /// MIME type of the image.
69        media_type: String,
70    },
71}
72
73/// A message in the provider conversation.
74#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
75pub struct ProviderMessage {
76    /// The role of the message author.
77    pub role: Role,
78    /// Content parts of the message.
79    pub content: Vec<ContentPart>,
80}
81
82/// JSON Schema description of a tool for the provider.
83#[derive(Debug, Clone, Serialize, Deserialize)]
84pub struct ToolSchema {
85    /// Tool name.
86    pub name: String,
87    /// Human-readable description.
88    pub description: String,
89    /// JSON Schema for the tool's input.
90    pub input_schema: serde_json::Value,
91}
92
93/// Request sent to a provider.
94#[derive(Debug, Clone, Serialize, Deserialize)]
95pub struct ProviderRequest {
96    /// Model to use (None = provider default).
97    pub model: Option<String>,
98    /// Conversation messages.
99    pub messages: Vec<ProviderMessage>,
100    /// Available tools.
101    pub tools: Vec<ToolSchema>,
102    /// Maximum output tokens.
103    pub max_tokens: Option<u32>,
104    /// Sampling temperature.
105    pub temperature: Option<f64>,
106    /// System prompt.
107    pub system: Option<String>,
108    /// Provider-specific config passthrough.
109    #[serde(default)]
110    pub extra: serde_json::Value,
111}
112
113/// Why the provider stopped generating.
114#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
115#[serde(rename_all = "snake_case")]
116pub enum StopReason {
117    /// Model produced a final response.
118    EndTurn,
119    /// Model wants to use a tool.
120    ToolUse,
121    /// Hit the max_tokens limit.
122    MaxTokens,
123    /// Content was filtered by safety.
124    ContentFilter,
125}
126
127/// Token usage from a single provider call.
128#[derive(Debug, Clone, Default, PartialEq, Serialize, Deserialize)]
129pub struct TokenUsage {
130    /// Input tokens consumed.
131    pub input_tokens: u64,
132    /// Output tokens generated.
133    pub output_tokens: u64,
134    /// Tokens read from cache (if supported).
135    pub cache_read_tokens: Option<u64>,
136    /// Tokens written to cache (if supported).
137    pub cache_creation_tokens: Option<u64>,
138}
139
140/// Response from a provider.
141#[derive(Debug, Clone, Serialize, Deserialize)]
142pub struct ProviderResponse {
143    /// Response content parts.
144    pub content: Vec<ContentPart>,
145    /// Why the provider stopped.
146    pub stop_reason: StopReason,
147    /// Token usage.
148    pub usage: TokenUsage,
149    /// Actual model used.
150    pub model: String,
151    /// Cost calculated by the provider (None if unknown).
152    pub cost: Option<Decimal>,
153    /// Whether the provider truncated input (telemetry only).
154    pub truncated: Option<bool>,
155}
156
#[cfg(test)]
mod tests {
    // Serde round-trip and wire-format tests for the provider types.
    use super::*;
    use serde_json::json;

    #[test]
    fn role_serde_roundtrip() {
        for role in [Role::System, Role::User, Role::Assistant] {
            let json = serde_json::to_string(&role).unwrap();
            let back: Role = serde_json::from_str(&json).unwrap();
            assert_eq!(role, back);
        }
    }

    #[test]
    fn role_serializes_lowercase() {
        for (role, wire) in [
            (Role::System, "\"system\""),
            (Role::User, "\"user\""),
            (Role::Assistant, "\"assistant\""),
        ] {
            assert_eq!(serde_json::to_string(&role).unwrap(), wire);
        }
    }

    #[test]
    fn content_part_text_roundtrip() {
        let part = ContentPart::Text {
            text: "hello".into(),
        };
        let json = serde_json::to_value(&part).unwrap();
        assert_eq!(json["type"], "text");
        let back: ContentPart = serde_json::from_value(json).unwrap();
        assert_eq!(part, back);
    }

    #[test]
    fn content_part_tool_use_roundtrip() {
        let part = ContentPart::ToolUse {
            id: "tu_1".into(),
            name: "bash".into(),
            input: json!({"command": "ls"}),
        };
        let json = serde_json::to_value(&part).unwrap();
        assert_eq!(json["type"], "tool_use");
        let back: ContentPart = serde_json::from_value(json).unwrap();
        assert_eq!(part, back);
    }

    #[test]
    fn content_part_tool_result_roundtrip() {
        let part = ContentPart::ToolResult {
            tool_use_id: "tu_1".into(),
            content: "file.txt".into(),
            is_error: false,
        };
        let json = serde_json::to_value(&part).unwrap();
        assert_eq!(json["type"], "tool_result");
        let back: ContentPart = serde_json::from_value(json).unwrap();
        assert_eq!(part, back);
    }

    #[test]
    fn content_part_image_roundtrip() {
        let part = ContentPart::Image {
            source: ImageSource::Url {
                url: "https://example.com/img.png".into(),
            },
            media_type: "image/png".into(),
        };
        let json = serde_json::to_value(&part).unwrap();
        assert_eq!(json["type"], "image");
        let back: ContentPart = serde_json::from_value(json).unwrap();
        assert_eq!(part, back);
    }

    #[test]
    fn stop_reason_roundtrip() {
        for reason in [
            StopReason::EndTurn,
            StopReason::ToolUse,
            StopReason::MaxTokens,
            StopReason::ContentFilter,
        ] {
            let json = serde_json::to_string(&reason).unwrap();
            let back: StopReason = serde_json::from_str(&json).unwrap();
            assert_eq!(reason, back);
        }
    }

    #[test]
    fn stop_reason_serializes_snake_case() {
        for (reason, wire) in [
            (StopReason::EndTurn, "\"end_turn\""),
            (StopReason::ToolUse, "\"tool_use\""),
            (StopReason::MaxTokens, "\"max_tokens\""),
            (StopReason::ContentFilter, "\"content_filter\""),
        ] {
            assert_eq!(serde_json::to_string(&reason).unwrap(), wire);
        }
    }

    #[test]
    fn provider_message_roundtrip() {
        let msg = ProviderMessage {
            role: Role::User,
            content: vec![ContentPart::Text {
                text: "hello".into(),
            }],
        };
        let json = serde_json::to_value(&msg).unwrap();
        let back: ProviderMessage = serde_json::from_value(json).unwrap();
        assert_eq!(msg, back);
    }

    #[test]
    fn token_usage_default() {
        let usage = TokenUsage::default();
        assert_eq!(usage.input_tokens, 0);
        assert_eq!(usage.output_tokens, 0);
        assert!(usage.cache_read_tokens.is_none());
        assert!(usage.cache_creation_tokens.is_none());
    }

    #[test]
    fn token_usage_serde_roundtrip() {
        let usage = TokenUsage {
            input_tokens: 100,
            output_tokens: 50,
            cache_read_tokens: Some(10),
            cache_creation_tokens: Some(5),
        };
        let json = serde_json::to_value(&usage).unwrap();
        let back: TokenUsage = serde_json::from_value(json).unwrap();
        assert_eq!(usage, back);
    }

    #[test]
    fn image_source_base64_roundtrip() {
        let source = ImageSource::Base64 {
            data: "aGVsbG8=".into(),
        };
        let json = serde_json::to_value(&source).unwrap();
        assert_eq!(json["type"], "base64");
        let back: ImageSource = serde_json::from_value(json).unwrap();
        assert_eq!(source, back);
    }

    #[test]
    fn image_source_url_roundtrip() {
        let source = ImageSource::Url {
            url: "https://example.com/img.png".into(),
        };
        let json = serde_json::to_value(&source).unwrap();
        assert_eq!(json["type"], "url");
        let back: ImageSource = serde_json::from_value(json).unwrap();
        assert_eq!(source, back);
    }

    #[test]
    fn provider_request_serde_roundtrip() {
        let request = ProviderRequest {
            model: Some("test-model".into()),
            messages: vec![ProviderMessage {
                role: Role::User,
                content: vec![ContentPart::Text {
                    text: "hello".into(),
                }],
            }],
            tools: vec![ToolSchema {
                name: "bash".into(),
                description: "Run a command".into(),
                input_schema: json!({"type": "object"}),
            }],
            max_tokens: Some(1024),
            temperature: Some(0.7),
            system: Some("Be helpful".into()),
            extra: json!({"key": "value"}),
        };
        let json = serde_json::to_value(&request).unwrap();
        let back: ProviderRequest = serde_json::from_value(json).unwrap();
        // Compare field by field so this test does not rely on
        // `ProviderRequest: PartialEq`; every field is checked.
        assert_eq!(back.model, Some("test-model".into()));
        assert_eq!(back.messages, request.messages);
        assert_eq!(back.tools.len(), 1);
        assert_eq!(back.tools[0].name, "bash");
        assert_eq!(back.tools[0].description, "Run a command");
        assert_eq!(back.tools[0].input_schema, json!({"type": "object"}));
        assert_eq!(back.max_tokens, Some(1024));
        assert_eq!(back.temperature, Some(0.7));
        assert_eq!(back.system, Some("Be helpful".into()));
        assert_eq!(back.extra, json!({"key": "value"}));
    }

    #[test]
    fn provider_request_extra_defaults_to_null() {
        // Only the required collections are present; optional fields and the
        // `#[serde(default)]` passthrough must fill themselves in.
        let request: ProviderRequest =
            serde_json::from_value(json!({"messages": [], "tools": []})).unwrap();
        assert!(request.model.is_none());
        assert!(request.max_tokens.is_none());
        assert!(request.temperature.is_none());
        assert!(request.system.is_none());
        assert!(request.extra.is_null());
    }

    #[test]
    fn provider_response_serde_roundtrip() {
        let response = ProviderResponse {
            content: vec![ContentPart::Text {
                text: "hello".into(),
            }],
            stop_reason: StopReason::EndTurn,
            usage: TokenUsage {
                input_tokens: 10,
                output_tokens: 5,
                cache_read_tokens: None,
                cache_creation_tokens: None,
            },
            model: "test-model".into(),
            cost: Some(rust_decimal::Decimal::new(1, 4)),
            truncated: None,
        };
        let json = serde_json::to_value(&response).unwrap();
        let back: ProviderResponse = serde_json::from_value(json).unwrap();
        // Field-by-field so this test does not rely on
        // `ProviderResponse: PartialEq`.
        assert_eq!(back.model, "test-model");
        assert_eq!(back.stop_reason, StopReason::EndTurn);
        assert_eq!(back.content, response.content);
        assert_eq!(back.usage, response.usage);
        assert_eq!(back.cost, response.cost);
        assert_eq!(back.truncated, None);
    }

    #[test]
    fn content_part_image_base64_roundtrip() {
        let part = ContentPart::Image {
            source: ImageSource::Base64 {
                data: "aGVsbG8=".into(),
            },
            media_type: "image/jpeg".into(),
        };
        let json = serde_json::to_value(&part).unwrap();
        assert_eq!(json["type"], "image");
        let back: ContentPart = serde_json::from_value(json).unwrap();
        assert_eq!(part, back);
    }

    #[test]
    fn provider_message_multi_content_roundtrip() {
        let msg = ProviderMessage {
            role: Role::Assistant,
            content: vec![
                ContentPart::Text {
                    text: "Let me help.".into(),
                },
                ContentPart::ToolUse {
                    id: "tu_1".into(),
                    name: "bash".into(),
                    input: json!({"cmd": "ls"}),
                },
            ],
        };
        let json = serde_json::to_value(&msg).unwrap();
        let back: ProviderMessage = serde_json::from_value(json).unwrap();
        assert_eq!(msg, back);
    }

    #[test]
    fn tool_result_with_error_roundtrip() {
        let part = ContentPart::ToolResult {
            tool_use_id: "tu_1".into(),
            content: "command failed".into(),
            is_error: true,
        };
        let json = serde_json::to_value(&part).unwrap();
        assert_eq!(json["type"], "tool_result");
        assert_eq!(json["is_error"], true);
        let back: ContentPart = serde_json::from_value(json).unwrap();
        assert_eq!(part, back);
    }
}
390}