Skip to main content

bamboo_infrastructure/llm/
models.rs

//! LLM API request and response models.
//!
//! This module defines the data structures used to communicate with
//! various LLM providers (OpenAI, Anthropic, etc.) following OpenAI's API format.
//!
//! # Key Types
//!
//! ## Request Types
//! - [`ChatCompletionRequest`] - Main request structure
//! - [`ChatMessage`] - Message in conversation
//!
//! ## Content Types
//! - [`Role`] - Message role (system, user, assistant, tool)
//! - [`Content`] - Message content (text or parts)
//! - [`ContentPart`] - Content part (text or image)
//!
//! ## Tool Types
//! - [`Tool`] - Tool definition
//! - [`ToolChoice`] - Tool selection strategy
//! - [`ToolCall`] - Tool invocation
//!
//! ## Response Types
//! - [`ChatCompletionResponse`] - Non-streaming response
//! - [`ChatCompletionStreamChunk`] - One chunk of a streaming response
//!
//! # Example
//!
//! ```rust,ignore
//! use bamboo_agent::agent::llm::models::*;
//! use std::collections::HashMap;
//!
//! let request = ChatCompletionRequest {
//!     model: "gpt-4o-mini".to_string(),
//!     messages: vec![
//!         ChatMessage {
//!             role: Role::User,
//!             content: Content::Text("Hello".to_string()),
//!             phase: None,
//!             tool_calls: None,
//!             tool_call_id: None,
//!         }
//!     ],
//!     tools: None,
//!     tool_choice: None,
//!     stream: Some(true),
//!     stream_options: Some(StreamOptions { include_usage: true }),
//!     parameters: HashMap::new(),
//! };
//! ```
44
45use serde::{de::Error as DeError, Deserialize, Deserializer, Serialize};
46use std::collections::HashMap;
47
// ========== Core Request Body ==========
49
50/// Chat completion request to LLM API.
51///
52/// Main request structure sent to LLM providers to generate
53/// chat completions with optional tool calling support.
54///
55/// # Fields
56///
57/// * `model` - Model identifier (e.g., "gpt-4o-mini", "claude-3-opus")
58/// * `messages` - Conversation history
59/// * `tools` - Available tools for the model
60/// * `tool_choice` - Tool selection strategy
61/// * `stream` - Whether to stream the response
62/// * `stream_options` - Streaming options
63/// * `parameters` - Additional model parameters (temperature, etc.)
64///
65/// # Example
66///
67/// ```rust,ignore
68/// let request = ChatCompletionRequest {
69///     model: "gpt-4o-mini".to_string(),
70///     messages: vec![
71///         ChatMessage::user("What is Rust?"),
72///     ],
73///     stream: Some(true),
74///     ..Default::default()
75/// };
76/// ```
77#[derive(Debug, Serialize, Deserialize, Clone, Default)]
78pub struct ChatCompletionRequest {
79    /// The model to use for the completion.
80    pub model: String,
81    /// A list of messages comprising the conversation so far.
82    pub messages: Vec<ChatMessage>,
83    /// A list of tools the model may call.
84    #[serde(skip_serializing_if = "Option::is_none")]
85    pub tools: Option<Vec<Tool>>,
86    /// Controls which function is called by the model.
87    #[serde(skip_serializing_if = "Option::is_none")]
88    pub tool_choice: Option<ToolChoice>,
89    /// Whether to stream the response.
90    #[serde(skip_serializing_if = "Option::is_none")]
91    pub stream: Option<bool>,
92    /// Options for streaming response. Set `include_usage: true` to receive usage information.
93    #[serde(skip_serializing_if = "Option::is_none")]
94    pub stream_options: Option<StreamOptions>,
95    /// Additional parameters like temperature, top_p, etc.
96    #[serde(flatten)]
97    pub parameters: HashMap<String, serde_json::Value>,
98}
99
100/// Options for streaming responses.
101///
102/// # Fields
103///
104/// * `include_usage` - Include token usage in final chunk
105#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
106pub struct StreamOptions {
107    /// If set to true, the streaming response will include a `usage` field in the final chunk.
108    pub include_usage: bool,
109}
110
// ========== Message and Content Structures ==========
112
113/// A message in the conversation.
114///
115/// Represents one turn in the conversation with role and content.
116///
117/// # Fields
118///
119/// * `role` - Message author role
120/// * `content` - Message contents
121/// * `tool_calls` - Tool calls (for assistant messages)
122/// * `tool_call_id` - Tool call ID (for tool result messages)
123///
124/// # Example
125///
126/// ```rust,ignore
127/// let user_msg = ChatMessage {
128///     role: Role::User,
129///     content: Content::Text("Hello".to_string()),
130///     tool_calls: None,
131///     tool_call_id: None,
132/// };
133///
134/// let tool_result = ChatMessage {
135///     role: Role::Tool,
136///     content: Content::Text("Result".to_string()),
137///     tool_calls: None,
138///     tool_call_id: Some("call-123".to_string()),
139/// };
140/// ```
141#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
142pub struct ChatMessage {
143    /// The role of the message author.
144    pub role: Role,
145    /// The contents of the message.
146    #[serde(deserialize_with = "deserialize_content")]
147    pub content: Content,
148    /// Optional Responses-style assistant phase (`commentary` / `final_answer`).
149    #[serde(default, skip_serializing_if = "Option::is_none")]
150    pub phase: Option<String>,
151    /// The tool calls generated by the model, if any.
152    #[serde(skip_serializing_if = "Option::is_none")]
153    pub tool_calls: Option<Vec<ToolCall>>,
154    /// The ID of the tool call this message is a response to.
155    #[serde(skip_serializing_if = "Option::is_none")]
156    pub tool_call_id: Option<String>,
157}
158
159fn deserialize_content<'de, D>(deserializer: D) -> Result<Content, D::Error>
160where
161    D: Deserializer<'de>,
162{
163    let value = serde_json::Value::deserialize(deserializer)?;
164    if value.is_null() {
165        // OpenAI-compatible clients may send `assistant.content = null` for tool-call turns.
166        return Ok(Content::Text(String::new()));
167    }
168    serde_json::from_value(value).map_err(D::Error::custom)
169}
170
171/// Role of a message author.
172///
173/// # Variants
174///
175/// * `System` - System instructions
176/// * `User` - User input
177/// * `Assistant` - AI response
178/// * `Tool` - Tool execution result
179#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
180#[serde(rename_all = "lowercase")]
181pub enum Role {
182    /// System instructions or prompts
183    #[serde(alias = "developer")]
184    System,
185    /// User input message
186    User,
187    /// AI assistant response
188    Assistant,
189    /// Tool execution result
190    Tool,
191}
192
#[cfg(test)]
mod tests {
    use super::*;

    /// Decodes a JSON value into a [`ChatMessage`], panicking if that fails.
    fn parse_message(value: serde_json::Value) -> ChatMessage {
        serde_json::from_value(value).expect("should deserialize")
    }

    /// A `null` content must come back as empty text rather than an error.
    #[test]
    fn chat_message_accepts_null_content_as_empty_text() {
        let msg = parse_message(serde_json::json!({
            "role": "assistant",
            "content": null
        }));

        assert_eq!(msg.role, Role::Assistant);
        assert_eq!(msg.content, Content::Text(String::new()));
    }

    /// The `developer` role string must be read as [`Role::System`].
    #[test]
    fn role_accepts_developer_alias() {
        let msg = parse_message(serde_json::json!({
            "role": "developer",
            "content": "You are a helpful assistant."
        }));

        let expected = Content::Text("You are a helpful assistant.".to_string());
        assert_eq!(msg.role, Role::System);
        assert_eq!(msg.content, expected);
    }
}
224
225/// Message content.
226///
227/// Can be either plain text or a list of content parts
228/// (for multimodal messages with text and images).
229///
230/// # Variants
231///
232/// * `Text(String)` - Simple text content
233/// * `Parts(Vec<ContentPart>)` - Multiple content parts
234///
235/// # Example
236///
237/// ```rust,ignore
238/// // Simple text
239/// let text = Content::Text("Hello".to_string());
240///
241/// // Multimodal
242/// let parts = Content::Parts(vec![
243///     ContentPart::Text { text: "What's in this image?".to_string() },
244///     ContentPart::ImageUrl { image_url: ImageUrl { url: "...".to_string(), detail: None } },
245/// ]);
246/// ```
247#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
248#[serde(untagged)]
249pub enum Content {
250    /// A single string of text content.
251    Text(String),
252    /// A list of content parts, for complex messages (e.g., with images).
253    Parts(Vec<ContentPart>),
254}
255
256/// A part of message content.
257///
258/// For multimodal messages, content can be text or image URLs.
259///
260/// # Variants
261///
262/// * `Text` - Text content
263/// * `ImageUrl` - Image URL reference
264///
265/// # Example
266///
267/// ```rust,ignore
268/// let text_part = ContentPart::Text {
269///     text: "Describe this image".to_string()
270/// };
271///
272/// let image_part = ContentPart::ImageUrl {
273///     image_url: ImageUrl {
274///         url: "https://example.com/image.png".to_string(),
275///         detail: Some("high".to_string()),
276///     }
277/// };
278/// ```
279#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
280#[serde(tag = "type", rename_all = "snake_case")]
281pub enum ContentPart {
282    /// Text content part
283    Text { text: String },
284    /// Image URL content part
285    ImageUrl { image_url: ImageUrl },
286}
287
288/// Image URL reference.
289///
290/// # Fields
291///
292/// * `url` - Image URL or base64 data URI
293/// * `detail` - Detail level ("low", "high", "auto")
294#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
295pub struct ImageUrl {
296    /// The URL of the image.
297    pub url: String,
298    /// The level of detail to use for the image.
299    #[serde(skip_serializing_if = "Option::is_none")]
300    pub detail: Option<String>,
301}
302
// ── Conversions between ContentPart (LLM layer) and MessagePart (domain) ──
304
305impl From<bamboo_domain::MessagePart> for ContentPart {
306    fn from(part: bamboo_domain::MessagePart) -> Self {
307        match part {
308            bamboo_domain::MessagePart::Text { text } => ContentPart::Text { text },
309            bamboo_domain::MessagePart::ImageUrl { image_url: url_ref } => ContentPart::ImageUrl {
310                image_url: ImageUrl {
311                    url: url_ref.url,
312                    detail: url_ref.detail,
313                },
314            },
315        }
316    }
317}
318
319impl From<ContentPart> for bamboo_domain::MessagePart {
320    fn from(part: ContentPart) -> Self {
321        match part {
322            ContentPart::Text { text } => bamboo_domain::MessagePart::Text { text },
323            ContentPart::ImageUrl { image_url } => bamboo_domain::MessagePart::ImageUrl {
324                image_url: bamboo_domain::ImageUrlRef {
325                    url: image_url.url,
326                    detail: image_url.detail,
327                },
328            },
329        }
330    }
331}
332
333impl From<ImageUrl> for bamboo_domain::ImageUrlRef {
334    fn from(url: ImageUrl) -> Self {
335        bamboo_domain::ImageUrlRef {
336            url: url.url,
337            detail: url.detail,
338        }
339    }
340}
341
342impl From<bamboo_domain::ImageUrlRef> for ImageUrl {
343    fn from(url: bamboo_domain::ImageUrlRef) -> Self {
344        ImageUrl {
345            url: url.url,
346            detail: url.detail,
347        }
348    }
349}
350
// ========== Tool-Related Structures ==========
352
353/// Tool definition for LLM function calling.
354///
355/// Defines a tool that the model can call during generation.
356///
357/// # Fields
358///
359/// * `tool_type` - Tool type (always "function")
360/// * `function` - Function definition
361///
362/// # Example
363///
364/// ```rust,ignore
365/// let tool = Tool {
366///     tool_type: "function".to_string(),
367///     function: FunctionDefinition {
368///         name: "read_file".to_string(),
369///         description: Some("Read file contents".to_string()),
370///         parameters: json!({
371///             "type": "object",
372///             "properties": {
373///                 "path": {"type": "string"}
374///             }
375///         }),
376///     },
377/// };
378/// ```
379#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
380pub struct Tool {
381    /// Tool type (always "function")
382    #[serde(rename = "type")]
383    pub tool_type: String,
384    /// Function definition
385    pub function: FunctionDefinition,
386}
387
388/// Function definition for tool schema.
389///
390/// # Fields
391///
392/// * `name` - Function name
393/// * `description` - Function description
394/// * `parameters` - JSON Schema for parameters
395#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
396pub struct FunctionDefinition {
397    /// Function name
398    pub name: String,
399    /// Function description for the model
400    #[serde(skip_serializing_if = "Option::is_none")]
401    pub description: Option<String>,
402    /// JSON Schema for function parameters
403    pub parameters: serde_json::Value, // JSON Schema
404}
405
406/// Tool selection strategy.
407///
408/// Controls which tool (if any) the model should call.
409///
410/// # Variants
411///
412/// * `String(String)` - "none", "auto", or "required"
413/// * `Object` - Specific function to call
414///
415/// # Example
416///
417/// ```rust,ignore
418/// // No tools
419/// let none = ToolChoice::String("none".to_string());
420///
421/// // Automatic selection
422/// let auto = ToolChoice::String("auto".to_string());
423///
424/// // Force specific tool
425/// let specific = ToolChoice::Object {
426///     tool_type: "function".to_string(),
427///     function: FunctionChoice {
428///         name: "read_file".to_string(),
429///     },
430/// };
431/// ```
432#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
433#[serde(untagged)]
434pub enum ToolChoice {
435    /// Tool selection mode: "none", "auto", or "required"
436    String(String),
437    /// Force specific function call
438    Object {
439        /// Tool type (always "function")
440        #[serde(rename = "type")]
441        tool_type: String,
442        /// Function to call
443        function: FunctionChoice,
444    },
445}
446
447/// Specific function choice for tool_choice.
448#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
449pub struct FunctionChoice {
450    /// Function name to call
451    pub name: String,
452}
453
454/// Tool call from the model.
455///
456/// Represents a tool invocation requested by the LLM.
457///
458/// # Fields
459///
460/// * `id` - Unique call identifier
461/// * `tool_type` - Tool type (always "function")
462/// * `function` - Function call details
463#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
464pub struct ToolCall {
465    /// Unique tool call identifier
466    pub id: String,
467    /// Tool type (always "function")
468    #[serde(rename = "type")]
469    pub tool_type: String,
470    /// Function call details
471    pub function: FunctionCall,
472}
473
474/// Function call details.
475///
476/// # Fields
477///
478/// * `name` - Function name
479/// * `arguments` - JSON-encoded arguments
480#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
481pub struct FunctionCall {
482    /// Function name to invoke
483    pub name: String,
484    /// JSON-encoded function arguments
485    pub arguments: String, // JSON string
486}
487
// ========== Response Structures ==========
489
490/// Chat completion response from LLM API.
491///
492/// # Fields
493///
494/// * `id` - Response ID
495/// * `object` - Object type
496/// * `created` - Creation timestamp
497/// * `model` - Model used
498/// * `choices` - Completion choices
499/// * `usage` - Token usage statistics
500/// * `system_fingerprint` - System fingerprint
501#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
502pub struct ChatCompletionResponse {
503    /// Response identifier
504    pub id: String,
505    /// Object type (e.g., "chat.completion")
506    #[serde(default)]
507    pub object: Option<String>,
508    /// Unix timestamp when response was created
509    #[serde(default)]
510    pub created: Option<u64>,
511    /// Model name used for generation
512    #[serde(default)]
513    pub model: Option<String>,
514    #[serde(default)]
515    pub choices: Vec<ResponseChoice>,
516    #[serde(default)]
517    pub usage: Option<Usage>,
518    #[serde(skip_serializing_if = "Option::is_none")]
519    pub system_fingerprint: Option<String>,
520}
521
522/// A completion choice in the response.
523///
524/// # Fields
525///
526/// * `index` - Choice index
527/// * `message` - Completion message
528/// * `finish_reason` - Reason for finishing
529#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
530pub struct ResponseChoice {
531    #[serde(default)]
532    pub index: u32,
533    pub message: ChatMessage,
534    #[serde(default)]
535    pub finish_reason: Option<String>,
536}
537
538/// Token usage statistics.
539///
540/// # Fields
541///
542/// * `prompt_tokens` - Tokens in prompt
543/// * `completion_tokens` - Tokens in completion
544/// * `total_tokens` - Total tokens used
545#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
546pub struct Usage {
547    /// Tokens in the prompt
548    #[serde(default)]
549    pub prompt_tokens: u32,
550    /// Tokens in the completion
551    #[serde(default)]
552    pub completion_tokens: u32,
553    /// Total tokens used
554    #[serde(default)]
555    pub total_tokens: u32,
556}
557
// ========== Streaming Response Structures ==========
559
560/// Streaming chat completion chunk.
561///
562/// Each chunk contains incremental updates during streaming.
563///
564/// # Fields
565///
566/// * `id` - Response ID
567/// * `object` - Object type
568/// * `created` - Creation timestamp
569/// * `model` - Model used
570/// * `choices` - Stream choices with deltas
571/// * `usage` - Usage statistics (final chunk only)
572#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
573pub struct ChatCompletionStreamChunk {
574    pub id: String,
575    #[serde(default)]
576    pub object: Option<String>,
577    pub created: u64,
578    #[serde(default)]
579    pub model: Option<String>,
580    pub choices: Vec<StreamChoice>,
581    /// Usage statistics for the request. Only present if `stream_options.include_usage` was set to true.
582    #[serde(skip_serializing_if = "Option::is_none")]
583    pub usage: Option<Usage>,
584}
585
586/// A choice in a streaming chunk.
587///
588/// # Fields
589///
590/// * `index` - Choice index
591/// * `delta` - Content delta
592/// * `finish_reason` - Reason for finishing (if complete)
593#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
594pub struct StreamChoice {
595    /// Choice index
596    pub index: u32,
597    /// Content delta
598    pub delta: StreamDelta,
599    /// Reason for finishing
600    #[serde(skip_serializing_if = "Option::is_none")]
601    pub finish_reason: Option<String>,
602}
603
// Streaming-specific tool call structures.
// These allow partial data, since the API sends tool calls incrementally across multiple chunks.
606
607/// Streaming tool call fragment.
608///
609/// During streaming, tool calls are sent incrementally across multiple chunks.
610/// The `index` field identifies which tool call each fragment belongs to.
611///
612/// # Fields
613///
614/// * `index` - Tool call index for reassembly
615/// * `id` - Tool call ID (first chunk only)
616/// * `tool_type` - Tool type (first chunk only)
617/// * `function` - Function call fragment
618#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
619pub struct StreamToolCall {
620    /// Index to identify which tool call this fragment belongs to
621    pub index: u32,
622    /// Tool call ID (only present in the first chunk)
623    #[serde(skip_serializing_if = "Option::is_none")]
624    pub id: Option<String>,
625    /// Tool type (only present in the first chunk)
626    #[serde(rename = "type")]
627    #[serde(skip_serializing_if = "Option::is_none")]
628    pub tool_type: Option<String>,
629    /// Function call data (may be partial)
630    #[serde(skip_serializing_if = "Option::is_none")]
631    pub function: Option<StreamFunctionCall>,
632}
633
634/// Streaming function call fragment.
635///
636/// # Fields
637///
638/// * `name` - Function name (first chunk only)
639/// * `arguments` - Arguments fragment (sent incrementally)
640#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
641pub struct StreamFunctionCall {
642    /// Function name (only present in the first chunk)
643    #[serde(skip_serializing_if = "Option::is_none")]
644    pub name: Option<String>,
645    /// Arguments (sent incrementally across chunks)
646    #[serde(skip_serializing_if = "Option::is_none")]
647    pub arguments: Option<String>,
648}
649
650/// Content delta in a streaming chunk.
651///
652/// Contains incremental content updates during streaming.
653///
654/// # Fields
655///
656/// * `role` - Message role (first chunk only)
657/// * `content` - Text content delta
658/// * `tool_calls` - Tool call fragments
659#[derive(Debug, Serialize, Deserialize, Clone, Default, PartialEq)]
660pub struct StreamDelta {
661    /// Message role (only present in first chunk)
662    #[serde(skip_serializing_if = "Option::is_none")]
663    pub role: Option<Role>,
664    /// Text content delta
665    #[serde(skip_serializing_if = "Option::is_none")]
666    pub content: Option<String>,
667    /// Tool call fragments
668    #[serde(skip_serializing_if = "Option::is_none")]
669    pub tool_calls: Option<Vec<StreamToolCall>>,
670}