1use super::response_format::StructuredOutputOptions;
2use super::sampling::{ProviderConfig, SamplingParams};
3use super::tools::McpTool;
4use crate::execution::context::RequestContext;
5use serde::{Deserialize, Serialize};
6
7#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
8#[serde(tag = "type", rename_all = "snake_case")]
9pub enum AiContentPart {
10 Text { text: String },
11 Image { mime_type: String, data: String },
12 Audio { mime_type: String, data: String },
13 Video { mime_type: String, data: String },
14}
15
16impl AiContentPart {
17 pub fn text(text: impl Into<String>) -> Self {
18 Self::Text { text: text.into() }
19 }
20
21 pub fn image(mime_type: impl Into<String>, data: impl Into<String>) -> Self {
22 Self::Image {
23 mime_type: mime_type.into(),
24 data: data.into(),
25 }
26 }
27
28 pub fn audio(mime_type: impl Into<String>, data: impl Into<String>) -> Self {
29 Self::Audio {
30 mime_type: mime_type.into(),
31 data: data.into(),
32 }
33 }
34
35 pub fn video(mime_type: impl Into<String>, data: impl Into<String>) -> Self {
36 Self::Video {
37 mime_type: mime_type.into(),
38 data: data.into(),
39 }
40 }
41
42 pub const fn is_media(&self) -> bool {
43 matches!(
44 self,
45 Self::Image { .. } | Self::Audio { .. } | Self::Video { .. }
46 )
47 }
48}
49
50#[derive(Debug, Clone, Serialize, Deserialize)]
51pub struct AiMessage {
52 pub role: MessageRole,
53 pub content: String,
54 #[serde(default, skip_serializing_if = "Vec::is_empty")]
55 pub parts: Vec<AiContentPart>,
56}
57
58impl AiMessage {
59 pub fn user(content: impl Into<String>) -> Self {
60 Self {
61 role: MessageRole::User,
62 content: content.into(),
63 parts: Vec::new(),
64 }
65 }
66
67 pub fn assistant(content: impl Into<String>) -> Self {
68 Self {
69 role: MessageRole::Assistant,
70 content: content.into(),
71 parts: Vec::new(),
72 }
73 }
74
75 pub fn system(content: impl Into<String>) -> Self {
76 Self {
77 role: MessageRole::System,
78 content: content.into(),
79 parts: Vec::new(),
80 }
81 }
82
83 pub fn user_with_parts(content: impl Into<String>, parts: Vec<AiContentPart>) -> Self {
84 Self {
85 role: MessageRole::User,
86 content: content.into(),
87 parts,
88 }
89 }
90
91 pub fn has_media(&self) -> bool {
92 self.parts.iter().any(AiContentPart::is_media)
93 }
94}
95
96#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
97#[serde(rename_all = "lowercase")]
98pub enum MessageRole {
99 System,
100 User,
101 Assistant,
102}
103
104#[derive(Debug, Clone, Serialize, Deserialize)]
105pub struct AiRequest {
106 pub messages: Vec<AiMessage>,
107 pub provider_config: ProviderConfig,
108 pub context: RequestContext,
109 pub sampling: Option<SamplingParams>,
110 pub tools: Option<Vec<McpTool>>,
111 pub structured_output: Option<StructuredOutputOptions>,
112 pub system_prompt: Option<String>,
113}
114
115impl AiRequest {
116 pub fn builder(
117 messages: Vec<AiMessage>,
118 provider: impl Into<String>,
119 model: impl Into<String>,
120 max_output_tokens: u32,
121 context: RequestContext,
122 ) -> AiRequestBuilder {
123 AiRequestBuilder::new(messages, provider, model, max_output_tokens, context)
124 }
125
126 pub fn has_tools(&self) -> bool {
127 self.tools.as_ref().is_some_and(|t| !t.is_empty())
128 }
129
130 pub fn provider(&self) -> &str {
131 &self.provider_config.provider
132 }
133
134 pub fn model(&self) -> &str {
135 &self.provider_config.model
136 }
137
138 pub const fn max_output_tokens(&self) -> u32 {
139 self.provider_config.max_output_tokens
140 }
141}
142
143#[derive(Debug)]
144pub struct AiRequestBuilder {
145 messages: Vec<AiMessage>,
146 provider_config: ProviderConfig,
147 context: RequestContext,
148 sampling: Option<SamplingParams>,
149 tools: Option<Vec<McpTool>>,
150 structured_output: Option<StructuredOutputOptions>,
151 system_prompt: Option<String>,
152}
153
154impl AiRequestBuilder {
155 pub fn new(
156 messages: Vec<AiMessage>,
157 provider: impl Into<String>,
158 model: impl Into<String>,
159 max_output_tokens: u32,
160 context: RequestContext,
161 ) -> Self {
162 Self {
163 messages,
164 provider_config: ProviderConfig::new(provider, model, max_output_tokens),
165 context,
166 sampling: None,
167 tools: None,
168 structured_output: None,
169 system_prompt: None,
170 }
171 }
172
173 pub fn with_sampling(mut self, sampling: SamplingParams) -> Self {
174 self.sampling = Some(sampling);
175 self
176 }
177
178 pub fn with_tools(mut self, tools: Vec<McpTool>) -> Self {
179 self.tools = Some(tools);
180 self
181 }
182
183 pub fn with_structured_output(mut self, options: StructuredOutputOptions) -> Self {
184 self.structured_output = Some(options);
185 self
186 }
187
188 pub fn with_system_prompt(mut self, prompt: impl Into<String>) -> Self {
189 self.system_prompt = Some(prompt.into());
190 self
191 }
192
193 pub fn build(self) -> AiRequest {
194 AiRequest {
195 messages: self.messages,
196 provider_config: self.provider_config,
197 context: self.context,
198 sampling: self.sampling,
199 tools: self.tools,
200 structured_output: self.structured_output,
201 system_prompt: self.system_prompt,
202 }
203 }
204}