Skip to main content

systemprompt_models/ai/
request.rs

//! Provider-agnostic inference request types.
//!
//! [`AiRequest`] is the unified request shape passed to any LLM provider: a
//! sequence of [`AiMessage`]s (each optionally carrying multimodal
//! [`AiContentPart`]s), the provider/model config, sampling params, available
//! tools, structured-output options, and an optional system prompt. Use
//! [`AiRequestBuilder`] to set the optional fields.
8
9use super::response_format::StructuredOutputOptions;
10use super::sampling::{ProviderConfig, SamplingParams};
11use super::tools::McpTool;
12use crate::execution::context::RequestContext;
13use serde::{Deserialize, Serialize};
14
15#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
16#[serde(tag = "type", rename_all = "snake_case")]
17pub enum AiContentPart {
18    Text { text: String },
19    Image { mime_type: String, data: String },
20    Audio { mime_type: String, data: String },
21    Video { mime_type: String, data: String },
22}
23
24impl AiContentPart {
25    pub fn text(text: impl Into<String>) -> Self {
26        Self::Text { text: text.into() }
27    }
28
29    pub fn image(mime_type: impl Into<String>, data: impl Into<String>) -> Self {
30        Self::Image {
31            mime_type: mime_type.into(),
32            data: data.into(),
33        }
34    }
35
36    pub fn audio(mime_type: impl Into<String>, data: impl Into<String>) -> Self {
37        Self::Audio {
38            mime_type: mime_type.into(),
39            data: data.into(),
40        }
41    }
42
43    pub fn video(mime_type: impl Into<String>, data: impl Into<String>) -> Self {
44        Self::Video {
45            mime_type: mime_type.into(),
46            data: data.into(),
47        }
48    }
49
50    pub const fn is_media(&self) -> bool {
51        matches!(
52            self,
53            Self::Image { .. } | Self::Audio { .. } | Self::Video { .. }
54        )
55    }
56}
57
58#[derive(Debug, Clone, Serialize, Deserialize)]
59pub struct AiMessage {
60    pub role: MessageRole,
61    pub content: String,
62    #[serde(default, skip_serializing_if = "Vec::is_empty")]
63    pub parts: Vec<AiContentPart>,
64}
65
66impl AiMessage {
67    pub fn user(content: impl Into<String>) -> Self {
68        Self {
69            role: MessageRole::User,
70            content: content.into(),
71            parts: Vec::new(),
72        }
73    }
74
75    pub fn assistant(content: impl Into<String>) -> Self {
76        Self {
77            role: MessageRole::Assistant,
78            content: content.into(),
79            parts: Vec::new(),
80        }
81    }
82
83    pub fn system(content: impl Into<String>) -> Self {
84        Self {
85            role: MessageRole::System,
86            content: content.into(),
87            parts: Vec::new(),
88        }
89    }
90
91    pub fn user_with_parts(content: impl Into<String>, parts: Vec<AiContentPart>) -> Self {
92        Self {
93            role: MessageRole::User,
94            content: content.into(),
95            parts,
96        }
97    }
98
99    pub fn has_media(&self) -> bool {
100        self.parts.iter().any(AiContentPart::is_media)
101    }
102}
103
104#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
105#[serde(rename_all = "lowercase")]
106pub enum MessageRole {
107    System,
108    User,
109    Assistant,
110}
111
112#[derive(Debug, Clone, Serialize, Deserialize)]
113pub struct AiRequest {
114    pub messages: Vec<AiMessage>,
115    pub provider_config: ProviderConfig,
116    pub context: RequestContext,
117    pub sampling: Option<SamplingParams>,
118    pub tools: Option<Vec<McpTool>>,
119    pub structured_output: Option<StructuredOutputOptions>,
120    pub system_prompt: Option<String>,
121}
122
123impl AiRequest {
124    pub fn builder(
125        messages: Vec<AiMessage>,
126        provider: impl Into<String>,
127        model: impl Into<String>,
128        max_output_tokens: u32,
129        context: RequestContext,
130    ) -> AiRequestBuilder {
131        AiRequestBuilder::new(messages, provider, model, max_output_tokens, context)
132    }
133
134    pub fn has_tools(&self) -> bool {
135        self.tools.as_ref().is_some_and(|t| !t.is_empty())
136    }
137
138    pub fn provider(&self) -> &str {
139        &self.provider_config.provider
140    }
141
142    pub fn model(&self) -> &str {
143        &self.provider_config.model
144    }
145
146    pub const fn max_output_tokens(&self) -> u32 {
147        self.provider_config.max_output_tokens
148    }
149}
150
151#[derive(Debug)]
152pub struct AiRequestBuilder {
153    messages: Vec<AiMessage>,
154    provider_config: ProviderConfig,
155    context: RequestContext,
156    sampling: Option<SamplingParams>,
157    tools: Option<Vec<McpTool>>,
158    structured_output: Option<StructuredOutputOptions>,
159    system_prompt: Option<String>,
160}
161
162impl AiRequestBuilder {
163    pub fn new(
164        messages: Vec<AiMessage>,
165        provider: impl Into<String>,
166        model: impl Into<String>,
167        max_output_tokens: u32,
168        context: RequestContext,
169    ) -> Self {
170        Self {
171            messages,
172            provider_config: ProviderConfig::new(provider, model, max_output_tokens),
173            context,
174            sampling: None,
175            tools: None,
176            structured_output: None,
177            system_prompt: None,
178        }
179    }
180
181    pub fn with_sampling(mut self, sampling: SamplingParams) -> Self {
182        self.sampling = Some(sampling);
183        self
184    }
185
186    pub fn with_tools(mut self, tools: Vec<McpTool>) -> Self {
187        self.tools = Some(tools);
188        self
189    }
190
191    pub fn with_structured_output(mut self, options: StructuredOutputOptions) -> Self {
192        self.structured_output = Some(options);
193        self
194    }
195
196    pub fn with_system_prompt(mut self, prompt: impl Into<String>) -> Self {
197        self.system_prompt = Some(prompt.into());
198        self
199    }
200
201    pub fn build(self) -> AiRequest {
202        AiRequest {
203            messages: self.messages,
204            provider_config: self.provider_config,
205            context: self.context,
206            sampling: self.sampling,
207            tools: self.tools,
208            structured_output: self.structured_output,
209            system_prompt: self.system_prompt,
210        }
211    }
212}