Skip to main content

iron_providers/
model.rs

//! Semantic provider request/response models.
//!
//! These types define the normalized boundary between iron-core and
//! provider implementations. They are intentionally domain-oriented
//! rather than mirroring any specific provider's wire format.
6
7use serde::{Deserialize, Serialize};
8use serde_json::Value;
9
/// Name of the reserved internal tool that providers use to recognise
/// model-originated choice requests and normalise them into first-class
/// `ProviderEvent::ChoiceRequest` events.
pub const CHOICE_REQUEST_TOOL_NAME: &str = "runtime.request_choice";
13
14/// A transcript of conversation messages
15#[derive(Debug, Clone, Default, PartialEq, Serialize, Deserialize)]
16pub struct Transcript {
17    /// Ordered conversation messages.
18    pub messages: Vec<Message>,
19}
20
21impl Transcript {
22    /// Create an empty transcript.
23    pub fn new() -> Self {
24        Self { messages: vec![] }
25    }
26
27    /// Create a transcript with the provided messages.
28    pub fn with_messages(messages: Vec<Message>) -> Self {
29        Self { messages }
30    }
31
32    /// Append a message to the transcript.
33    pub fn add_message(&mut self, message: Message) {
34        self.messages.push(message);
35    }
36
37    /// Return whether the transcript contains no messages.
38    pub fn is_empty(&self) -> bool {
39        self.messages.is_empty()
40    }
41}
42
43/// A message in the conversation transcript
44#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
45#[serde(tag = "role", rename_all = "snake_case")]
46pub enum Message {
47    /// User message with text content
48    User { content: String },
49    /// Assistant message with text content
50    Assistant { content: String },
51    /// Assistant tool call (the model requesting to call a tool)
52    AssistantToolCall {
53        /// Stable tool call identifier.
54        call_id: String,
55        /// Requested tool name.
56        tool_name: String,
57        /// Parsed tool arguments.
58        arguments: Value,
59    },
60    /// Tool result message
61    Tool {
62        /// Stable tool call identifier.
63        call_id: String,
64        /// Tool name associated with the result.
65        tool_name: String,
66        /// Structured tool result.
67        result: Value,
68    },
69}
70
71impl Message {
72    /// Create a user message
73    pub fn user<S: Into<String>>(content: S) -> Self {
74        Self::User {
75            content: content.into(),
76        }
77    }
78
79    /// Create an assistant message
80    pub fn assistant<S: Into<String>>(content: S) -> Self {
81        Self::Assistant {
82            content: content.into(),
83        }
84    }
85
86    /// Create a tool result message.
87    pub fn tool<S1: Into<String>, S2: Into<String>>(
88        call_id: S1,
89        tool_name: S2,
90        result: Value,
91    ) -> Self {
92        Self::Tool {
93            call_id: call_id.into(),
94            tool_name: tool_name.into(),
95            result,
96        }
97    }
98}
99
100/// Selection cardinality for a provider-originated choice request.
101#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
102#[serde(rename_all = "snake_case")]
103pub enum ChoiceSelectionMode {
104    Single,
105    Multiple,
106}
107
108/// One selectable item in a provider-originated choice request.
109#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
110pub struct ChoiceItem {
111    pub id: String,
112    pub label: String,
113    #[serde(skip_serializing_if = "Option::is_none")]
114    pub description: Option<String>,
115}
116
117/// A first-class model-originated choice request surfaced by the provider/runtime layer.
118#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
119pub struct ChoiceRequest {
120    pub prompt: String,
121    pub selection_mode: ChoiceSelectionMode,
122    pub items: Vec<ChoiceItem>,
123}
124
125impl ChoiceRequest {
126    /// Parse a choice request from a structured JSON value.
127    pub fn from_value(value: Value) -> Result<Self, serde_json::Error> {
128        serde_json::from_value(value)
129    }
130}
131
132/// Model-facing tool definition
133#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
134pub struct ToolDefinition {
135    /// Unique tool name.
136    pub name: String,
137    /// Natural-language tool description.
138    pub description: String,
139    /// JSON Schema describing tool arguments.
140    pub input_schema: Value,
141}
142
143impl ToolDefinition {
144    /// Create a new tool definition.
145    pub fn new<S1: Into<String>, S2: Into<String>>(
146        name: S1,
147        description: S2,
148        input_schema: Value,
149    ) -> Self {
150        Self {
151            name: name.into(),
152            description: description.into(),
153            input_schema,
154        }
155    }
156}
157
158/// Tool choice policy
159#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Default)]
160#[serde(rename_all = "snake_case")]
161pub enum ToolPolicy {
162    /// No tools allowed
163    None,
164    /// Model can choose to use tools
165    #[default]
166    Auto,
167    /// Model must use a tool
168    Required,
169    /// Model must use a specific tool
170    Specific(String),
171}
172
173/// Normalized generation configuration
174#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Default)]
175pub struct GenerationConfig {
176    /// Temperature for sampling (0.0 to 2.0)
177    #[serde(skip_serializing_if = "Option::is_none")]
178    pub temperature: Option<f32>,
179    /// Maximum tokens to generate
180    #[serde(skip_serializing_if = "Option::is_none")]
181    pub max_tokens: Option<u32>,
182    /// Top-p sampling parameter
183    #[serde(skip_serializing_if = "Option::is_none")]
184    pub top_p: Option<f32>,
185    /// Stop sequences
186    #[serde(skip_serializing_if = "Option::is_none")]
187    pub stop: Option<Vec<String>>,
188}
189
190impl GenerationConfig {
191    /// Create an empty generation configuration.
192    pub fn new() -> Self {
193        Self::default()
194    }
195
196    /// Set the sampling temperature.
197    pub fn with_temperature(mut self, temp: f32) -> Self {
198        self.temperature = Some(temp);
199        self
200    }
201
202    /// Set the maximum output token count.
203    pub fn with_max_tokens(mut self, max: u32) -> Self {
204        self.max_tokens = Some(max);
205        self
206    }
207
208    /// Set the top-p sampling value.
209    pub fn with_top_p(mut self, top_p: f32) -> Self {
210        self.top_p = Some(top_p);
211        self
212    }
213}
214
215/// A completed tool call with structured JSON arguments
216#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
217pub struct ToolCall {
218    /// Stable tool call identifier.
219    pub call_id: String,
220    /// Tool name selected by the model.
221    pub tool_name: String,
222    /// Parsed tool arguments.
223    pub arguments: Value,
224}
225
226impl ToolCall {
227    /// Create a normalized tool call record.
228    pub fn new<S1: Into<String>, S2: Into<String>>(
229        call_id: S1,
230        tool_name: S2,
231        arguments: Value,
232    ) -> Self {
233        Self {
234            call_id: call_id.into(),
235            tool_name: tool_name.into(),
236            arguments,
237        }
238    }
239}
240
241/// Normalized provider-reported token usage for a single inference request.
242///
243/// All fields are optional because provider families differ in what they
244/// return.  When present, each value represents the provider's cumulative
245/// snapshot for the current request, not an incremental delta.
246#[derive(Debug, Clone, Default, PartialEq, Serialize, Deserialize)]
247pub struct TokenUsage {
248    /// Input or prompt tokens reported by the provider.
249    #[serde(skip_serializing_if = "Option::is_none")]
250    pub input_tokens: Option<u64>,
251    /// Output or completion tokens reported by the provider.
252    #[serde(skip_serializing_if = "Option::is_none")]
253    pub output_tokens: Option<u64>,
254    /// Total tokens reported by the provider.
255    #[serde(skip_serializing_if = "Option::is_none")]
256    pub total_tokens: Option<u64>,
257    /// Cached input tokens reported by OpenAI-style providers.
258    #[serde(skip_serializing_if = "Option::is_none")]
259    pub cached_input_tokens: Option<u64>,
260    /// Cache creation input tokens reported by Anthropic-style providers.
261    #[serde(skip_serializing_if = "Option::is_none")]
262    pub cache_creation_input_tokens: Option<u64>,
263    /// Cache read input tokens reported by Anthropic-style providers.
264    #[serde(skip_serializing_if = "Option::is_none")]
265    pub cache_read_input_tokens: Option<u64>,
266    /// Reasoning or thinking output tokens reported by the provider.
267    #[serde(skip_serializing_if = "Option::is_none")]
268    pub reasoning_output_tokens: Option<u64>,
269}
270
271impl TokenUsage {
272    /// Create an empty usage snapshot.
273    pub fn new() -> Self {
274        Self::default()
275    }
276}
277
278/// Events emitted by the provider during streaming
279///
280/// ## Stream termination contract
281///
282/// - `Complete` is emitted **only** on successful stream termination.
283/// - If a provider encounters an unrecoverable error, the stream ends
284///   with `Error` and does **not** emit `Complete`.
285/// - `Status` events are informational and do not affect termination.
286/// - `Usage` events carry cumulative provider-reported token usage.  When
287///   multiple `Usage` events appear for the same request, the latest one
288///   supersedes earlier snapshots rather than being additive.
289#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
290#[serde(tag = "type", rename_all = "snake_case")]
291pub enum ProviderEvent {
292    /// Status update
293    Status { message: String },
294    /// Incremental text output
295    Output { content: String },
296    /// Completed tool call
297    ToolCall { call: ToolCall },
298    /// Structured model-originated choice request.
299    ChoiceRequest { request: ChoiceRequest },
300    /// Provider-reported token usage snapshot.
301    ///
302    /// Represents the provider's cumulative usage for the current request.
303    /// Consumers should treat later `Usage` events as superseding earlier
304    /// ones rather than adding them together.
305    Usage { usage: TokenUsage },
306    /// Stream completed successfully.
307    ///
308    /// This event is emitted exactly once per successful stream and is
309    /// never emitted after an unrecoverable error.
310    Complete,
311    /// Error occurred during streaming.
312    ///
313    /// Carries a structured [`ProviderError`](crate::ProviderError) so
314    /// downstream consumers can programmatically classify the failure
315    /// (authentication, rate-limit, transport, etc.).
316    ///
317    /// If this represents an unrecoverable error, the stream ends
318    /// without a subsequent `Complete` event.
319    Error { source: crate::ProviderError },
320}
321
322/// A runtime-owned record that is **not** model-visible.
323///
324/// Runtime records carry structured context (e.g. resolved interaction
325/// records, session metadata) that should be available to provider
326/// adapters for request assembly but must not be projected into the
327/// model-visible conversation transcript.
328#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
329pub struct RuntimeRecord {
330    /// Stable record kind (e.g. "interaction", "session_state").
331    pub kind: String,
332    /// Structured payload.
333    pub payload: Value,
334}
335
336impl RuntimeRecord {
337    /// Create a new runtime record.
338    pub fn new<S: Into<String>>(kind: S, payload: Value) -> Self {
339        Self {
340            kind: kind.into(),
341            payload,
342        }
343    }
344}
345
346/// Inference context separating model-visible conversation from runtime-only state.
347///
348/// Provider adapters receive the full context but must only project the
349/// `transcript` into model-visible request fields. Runtime records may
350/// influence request assembly (e.g. system instructions, metadata headers)
351/// through explicit provider-specific mapping logic.
352#[derive(Debug, Clone, Default, PartialEq, Serialize, Deserialize)]
353pub struct InferenceContext {
354    /// Model-visible conversation transcript.
355    pub transcript: Transcript,
356    /// Runtime-only records that are not replayed into model context.
357    #[serde(default, skip_serializing_if = "Vec::is_empty")]
358    pub runtime_records: Vec<RuntimeRecord>,
359}
360
361impl InferenceContext {
362    /// Create an empty context.
363    pub fn new() -> Self {
364        Self::default()
365    }
366
367    /// Create a context with only a transcript (no runtime records).
368    pub fn from_transcript(transcript: Transcript) -> Self {
369        Self {
370            transcript,
371            runtime_records: vec![],
372        }
373    }
374
375    /// Add a runtime record.
376    pub fn add_record(&mut self, record: RuntimeRecord) {
377        self.runtime_records.push(record);
378    }
379}
380
381/// Semantic inference request
382#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
383pub struct InferenceRequest {
384    /// Model identifier
385    pub model: String,
386    /// Optional top-level instructions (system prompt)
387    #[serde(skip_serializing_if = "Option::is_none")]
388    pub instructions: Option<String>,
389    /// Inference context containing model-visible transcript and runtime-only records.
390    pub context: InferenceContext,
391    /// Available tools
392    #[serde(skip_serializing_if = "Vec::is_empty", default)]
393    pub tools: Vec<ToolDefinition>,
394    /// Tool usage policy
395    #[serde(default)]
396    pub tool_policy: ToolPolicy,
397    /// Generation settings
398    #[serde(default)]
399    pub generation: GenerationConfig,
400}
401
402impl InferenceRequest {
403    /// Create a new inference request for the provided model and transcript.
404    pub fn new<S: Into<String>>(model: S, transcript: Transcript) -> Self {
405        Self {
406            model: model.into(),
407            instructions: None,
408            context: InferenceContext::from_transcript(transcript),
409            tools: vec![],
410            tool_policy: ToolPolicy::default(),
411            generation: GenerationConfig::default(),
412        }
413    }
414
415    /// Validate that the model identifier is present and non-empty.
416    ///
417    /// Called by all provider adapters before constructing a request.
418    pub fn validate_model(&self) -> crate::ProviderResult<()> {
419        if self.model.trim().is_empty() {
420            return Err(crate::ProviderError::invalid_request(
421                "InferenceRequest.model must be a non-empty model identifier",
422            ));
423        }
424        Ok(())
425    }
426
427    /// Set top-level instructions for the request.
428    pub fn with_instructions<S: Into<String>>(mut self, instructions: S) -> Self {
429        self.instructions = Some(instructions.into());
430        self
431    }
432
433    /// Attach tool definitions to the request.
434    pub fn with_tools(mut self, tools: Vec<ToolDefinition>) -> Self {
435        self.tools = tools;
436        self
437    }
438
439    /// Set the tool policy for the request.
440    pub fn with_tool_policy(mut self, policy: ToolPolicy) -> Self {
441        self.tool_policy = policy;
442        self
443    }
444
445    /// Set generation parameters for the request.
446    pub fn with_generation(mut self, generation: GenerationConfig) -> Self {
447        self.generation = generation;
448        self
449    }
450}