1use super::response_format::StructuredOutputOptions;
10use super::sampling::{ProviderConfig, SamplingParams};
11use super::tools::McpTool;
12use crate::execution::context::RequestContext;
13use serde::{Deserialize, Serialize};
14
15#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
16#[serde(tag = "type", rename_all = "snake_case")]
17pub enum AiContentPart {
18 Text { text: String },
19 Image { mime_type: String, data: String },
20 Audio { mime_type: String, data: String },
21 Video { mime_type: String, data: String },
22}
23
24impl AiContentPart {
25 pub fn text(text: impl Into<String>) -> Self {
26 Self::Text { text: text.into() }
27 }
28
29 pub fn image(mime_type: impl Into<String>, data: impl Into<String>) -> Self {
30 Self::Image {
31 mime_type: mime_type.into(),
32 data: data.into(),
33 }
34 }
35
36 pub fn audio(mime_type: impl Into<String>, data: impl Into<String>) -> Self {
37 Self::Audio {
38 mime_type: mime_type.into(),
39 data: data.into(),
40 }
41 }
42
43 pub fn video(mime_type: impl Into<String>, data: impl Into<String>) -> Self {
44 Self::Video {
45 mime_type: mime_type.into(),
46 data: data.into(),
47 }
48 }
49
50 pub const fn is_media(&self) -> bool {
51 matches!(
52 self,
53 Self::Image { .. } | Self::Audio { .. } | Self::Video { .. }
54 )
55 }
56}
57
58#[derive(Debug, Clone, Serialize, Deserialize)]
59pub struct AiMessage {
60 pub role: MessageRole,
61 pub content: String,
62 #[serde(default, skip_serializing_if = "Vec::is_empty")]
63 pub parts: Vec<AiContentPart>,
64}
65
66impl AiMessage {
67 pub fn user(content: impl Into<String>) -> Self {
68 Self {
69 role: MessageRole::User,
70 content: content.into(),
71 parts: Vec::new(),
72 }
73 }
74
75 pub fn assistant(content: impl Into<String>) -> Self {
76 Self {
77 role: MessageRole::Assistant,
78 content: content.into(),
79 parts: Vec::new(),
80 }
81 }
82
83 pub fn system(content: impl Into<String>) -> Self {
84 Self {
85 role: MessageRole::System,
86 content: content.into(),
87 parts: Vec::new(),
88 }
89 }
90
91 pub fn user_with_parts(content: impl Into<String>, parts: Vec<AiContentPart>) -> Self {
92 Self {
93 role: MessageRole::User,
94 content: content.into(),
95 parts,
96 }
97 }
98
99 pub fn has_media(&self) -> bool {
100 self.parts.iter().any(AiContentPart::is_media)
101 }
102}
103
104#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
105#[serde(rename_all = "lowercase")]
106pub enum MessageRole {
107 System,
108 User,
109 Assistant,
110}
111
112#[derive(Debug, Clone, Serialize, Deserialize)]
113pub struct AiRequest {
114 pub messages: Vec<AiMessage>,
115 pub provider_config: ProviderConfig,
116 pub context: RequestContext,
117 pub sampling: Option<SamplingParams>,
118 pub tools: Option<Vec<McpTool>>,
119 pub structured_output: Option<StructuredOutputOptions>,
120 pub system_prompt: Option<String>,
121}
122
123impl AiRequest {
124 pub fn builder(
125 messages: Vec<AiMessage>,
126 provider: impl Into<String>,
127 model: impl Into<String>,
128 max_output_tokens: u32,
129 context: RequestContext,
130 ) -> AiRequestBuilder {
131 AiRequestBuilder::new(messages, provider, model, max_output_tokens, context)
132 }
133
134 pub fn has_tools(&self) -> bool {
135 self.tools.as_ref().is_some_and(|t| !t.is_empty())
136 }
137
138 pub fn provider(&self) -> &str {
139 &self.provider_config.provider
140 }
141
142 pub fn model(&self) -> &str {
143 &self.provider_config.model
144 }
145
146 pub const fn max_output_tokens(&self) -> u32 {
147 self.provider_config.max_output_tokens
148 }
149}
150
151#[derive(Debug)]
152pub struct AiRequestBuilder {
153 messages: Vec<AiMessage>,
154 provider_config: ProviderConfig,
155 context: RequestContext,
156 sampling: Option<SamplingParams>,
157 tools: Option<Vec<McpTool>>,
158 structured_output: Option<StructuredOutputOptions>,
159 system_prompt: Option<String>,
160}
161
162impl AiRequestBuilder {
163 pub fn new(
164 messages: Vec<AiMessage>,
165 provider: impl Into<String>,
166 model: impl Into<String>,
167 max_output_tokens: u32,
168 context: RequestContext,
169 ) -> Self {
170 Self {
171 messages,
172 provider_config: ProviderConfig::new(provider, model, max_output_tokens),
173 context,
174 sampling: None,
175 tools: None,
176 structured_output: None,
177 system_prompt: None,
178 }
179 }
180
181 pub fn with_sampling(mut self, sampling: SamplingParams) -> Self {
182 self.sampling = Some(sampling);
183 self
184 }
185
186 pub fn with_tools(mut self, tools: Vec<McpTool>) -> Self {
187 self.tools = Some(tools);
188 self
189 }
190
191 pub fn with_structured_output(mut self, options: StructuredOutputOptions) -> Self {
192 self.structured_output = Some(options);
193 self
194 }
195
196 pub fn with_system_prompt(mut self, prompt: impl Into<String>) -> Self {
197 self.system_prompt = Some(prompt.into());
198 self
199 }
200
201 pub fn build(self) -> AiRequest {
202 AiRequest {
203 messages: self.messages,
204 provider_config: self.provider_config,
205 context: self.context,
206 sampling: self.sampling,
207 tools: self.tools,
208 structured_output: self.structured_output,
209 system_prompt: self.system_prompt,
210 }
211 }
212}