1use async_trait::async_trait;
7use serde::{Deserialize, Serialize};
8
/// Identifier of a provider-native model tool.
///
/// Built through [`NativeModelToolId::new`], which trims surrounding whitespace
/// and rejects empty values; the hand-written `Deserialize` impl applies the
/// same validation. Serializes as a bare string (`#[serde(transparent)]`).
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize)]
#[serde(transparent)]
pub struct NativeModelToolId(String);
16
17impl NativeModelToolId {
18 pub fn new(value: impl Into<String>) -> Result<Self, String> {
19 let value = value.into();
20 let trimmed = value.trim();
21 if trimmed.is_empty() {
22 return Err("native model tool id cannot be empty".to_string());
23 }
24 Ok(Self(trimmed.to_string()))
25 }
26
27 pub fn as_str(&self) -> &str {
28 &self.0
29 }
30}
31
32impl From<NativeModelToolId> for String {
33 fn from(value: NativeModelToolId) -> Self {
34 value.0
35 }
36}
37
38impl std::fmt::Display for NativeModelToolId {
39 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
40 f.write_str(self.as_str())
41 }
42}
43
44impl std::str::FromStr for NativeModelToolId {
45 type Err = String;
46
47 fn from_str(value: &str) -> Result<Self, Self::Err> {
48 Self::new(value)
49 }
50}
51
52impl<'de> Deserialize<'de> for NativeModelToolId {
53 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
54 where
55 D: serde::Deserializer<'de>,
56 {
57 let value = String::deserialize(deserializer)?;
58 Self::new(value).map_err(serde::de::Error::custom)
59 }
60}
61
62impl From<&str> for NativeModelToolId {
63 fn from(value: &str) -> Self {
64 Self::new(value).expect("static native model tool id should be valid")
65 }
66}
67
/// Provider-native model tool, as listed in
/// [`ProviderNativeCapabilities::model_tools`].
///
/// The optional schemas are omitted from serialized output when `None` and
/// default to `None` when absent on input.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProviderNativeModelToolSpec {
    /// Validated tool identifier.
    pub id: NativeModelToolId,
    /// Provider-specific tool type; values are provider-defined (TODO confirm).
    pub provider_type: String,
    /// Tool name.
    pub name: String,
    /// Human-readable description.
    pub description: String,
    /// Optional parameters schema (presumably JSON Schema — confirm).
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub parameters_schema: Option<serde_json::Value>,
    /// Optional configuration schema (presumably JSON Schema — confirm).
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub config_schema: Option<serde_json::Value>,
}
86
/// Kind of native media operation.
///
/// Serialized in `snake_case` (e.g. `ImageToVideo` <-> `"image_to_video"`),
/// the same strings returned by [`NativeOperation::as_str`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum NativeOperation {
    GenerateImage,
    EditImage,
    GenerateVideo,
    EditVideo,
    ImageToVideo,
    ReferenceToVideo,
    ExtendVideo,
    GenerateSpeech,
    TranscribeAudio,
    /// Has no [`NativeMediaRequest`] variant, and
    /// [`NativeOperation::tool_name`] returns `None` for it.
    RealtimeVoiceAgent,
}
102
103impl NativeOperation {
104 pub fn as_str(self) -> &'static str {
105 match self {
106 Self::GenerateImage => "generate_image",
107 Self::EditImage => "edit_image",
108 Self::GenerateVideo => "generate_video",
109 Self::EditVideo => "edit_video",
110 Self::ImageToVideo => "image_to_video",
111 Self::ReferenceToVideo => "reference_to_video",
112 Self::ExtendVideo => "extend_video",
113 Self::GenerateSpeech => "generate_speech",
114 Self::TranscribeAudio => "transcribe_audio",
115 Self::RealtimeVoiceAgent => "realtime_voice_agent",
116 }
117 }
118
119 pub fn tool_name(self) -> Option<&'static str> {
120 match self {
121 Self::RealtimeVoiceAgent => None,
122 operation => Some(operation.as_str()),
123 }
124 }
125}
126
127#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
129#[serde(rename_all = "snake_case")]
130#[derive(Default)]
131pub enum MediaOutputFormat {
132 #[default]
133 Url,
134 Base64,
135}
136
/// Reference to an input media asset.
///
/// Internally tagged by `type`: `{"type": "url", "url": ...}`,
/// `{"type": "data_uri", "data_uri": ...}` or
/// `{"type": "provider_file_id", "file_id": ...}`. `media_input_schema` in
/// this module describes the same JSON shape.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum MediaInputAsset {
    /// Asset addressed by URL.
    Url { url: String },
    /// Asset embedded as a data URI (presumably `data:<mime>;base64,...` — confirm).
    DataUri { data_uri: String },
    /// Provider-side file referenced by its id.
    ProviderFileId { file_id: String },
}
145
/// Media asset returned in [`NativeMediaResponse::Assets`].
///
/// Internally tagged by `type` (`"url"`, `"base64"`, `"provider_file_id"`).
/// NOTE(review): unlike most optional fields in this module, `mime_type` has no
/// `skip_serializing_if`, so `None` serializes as `"mime_type": null`; a missing
/// key still deserializes to `None`. Confirm the asymmetry is intended.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum MediaOutputAsset {
    /// Asset available at a URL.
    Url {
        url: String,
        mime_type: Option<String>,
    },
    /// Asset inline as base64 text.
    Base64 {
        data: String,
        mime_type: Option<String>,
    },
    /// Provider-side file referenced by its id.
    ProviderFileId {
        file_id: String,
        mime_type: Option<String>,
    },
}
163
/// Parameters for [`NativeOperation::GenerateImage`].
///
/// `None` options and a `null` `provider_options` are omitted when serializing
/// and take their defaults when absent on input.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GenerateImageRequest {
    /// Model to use.
    pub model: String,
    /// Text prompt.
    pub prompt: String,
    /// Presumably the number of images requested — confirm.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub n: Option<u32>,
    /// Requested size; string format is provider-defined (TODO confirm).
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub size: Option<String>,
    /// Requested aspect ratio; string format is provider-defined (TODO confirm).
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub aspect_ratio: Option<String>,
    /// Requested resolution; string format is provider-defined (TODO confirm).
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub resolution: Option<String>,
    /// Delivery format; defaults to [`MediaOutputFormat::Url`] and is always serialized.
    #[serde(default)]
    pub output_format: MediaOutputFormat,
    /// Opaque provider-specific options; `null` (the default) is omitted when serializing.
    #[serde(default, skip_serializing_if = "serde_json::Value::is_null")]
    pub provider_options: serde_json::Value,
}
182
/// Parameters for [`NativeOperation::EditImage`].
///
/// `None` options and a `null` `provider_options` are omitted when serializing
/// and take their defaults when absent on input.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EditImageRequest {
    /// Model to use.
    pub model: String,
    /// Text prompt describing the edit.
    pub prompt: String,
    /// Source image.
    pub image: MediaInputAsset,
    /// Requested aspect ratio; string format is provider-defined (TODO confirm).
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub aspect_ratio: Option<String>,
    /// Requested resolution; string format is provider-defined (TODO confirm).
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub resolution: Option<String>,
    /// Delivery format; defaults to [`MediaOutputFormat::Url`] and is always serialized.
    #[serde(default)]
    pub output_format: MediaOutputFormat,
    /// Opaque provider-specific options; `null` (the default) is omitted when serializing.
    #[serde(default, skip_serializing_if = "serde_json::Value::is_null")]
    pub provider_options: serde_json::Value,
}
198
/// Parameters for [`NativeOperation::GenerateVideo`].
///
/// `None` options and a `null` `provider_options` are omitted when serializing
/// and take their defaults when absent on input.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GenerateVideoRequest {
    /// Model to use.
    pub model: String,
    /// Text prompt.
    pub prompt: String,
    /// Requested duration in whole seconds (per the field name).
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub duration_seconds: Option<u32>,
    /// Requested aspect ratio; string format is provider-defined (TODO confirm).
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub aspect_ratio: Option<String>,
    /// Requested resolution; string format is provider-defined (TODO confirm).
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub resolution: Option<String>,
    /// Opaque provider-specific options; `null` (the default) is omitted when serializing.
    #[serde(default, skip_serializing_if = "serde_json::Value::is_null")]
    pub provider_options: serde_json::Value,
}
213
/// Parameters for [`NativeOperation::EditVideo`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EditVideoRequest {
    /// Model to use.
    pub model: String,
    /// Text prompt describing the edit.
    pub prompt: String,
    /// Source video.
    pub video: MediaInputAsset,
    /// Opaque provider-specific options; `null` (the default) is omitted when serializing.
    #[serde(default, skip_serializing_if = "serde_json::Value::is_null")]
    pub provider_options: serde_json::Value,
}
223
/// Parameters for [`NativeOperation::ImageToVideo`].
///
/// `None` options and a `null` `provider_options` are omitted when serializing
/// and take their defaults when absent on input.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ImageToVideoRequest {
    /// Model to use.
    pub model: String,
    /// Text prompt.
    pub prompt: String,
    /// Source image.
    pub image: MediaInputAsset,
    /// Requested duration in whole seconds (per the field name).
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub duration_seconds: Option<u32>,
    /// Requested aspect ratio; string format is provider-defined (TODO confirm).
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub aspect_ratio: Option<String>,
    /// Requested resolution; string format is provider-defined (TODO confirm).
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub resolution: Option<String>,
    /// Opaque provider-specific options; `null` (the default) is omitted when serializing.
    #[serde(default, skip_serializing_if = "serde_json::Value::is_null")]
    pub provider_options: serde_json::Value,
}
239
/// Parameters for [`NativeOperation::ReferenceToVideo`].
///
/// `None` options and a `null` `provider_options` are omitted when serializing
/// and take their defaults when absent on input.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ReferenceToVideoRequest {
    /// Model to use.
    pub model: String,
    /// Text prompt.
    pub prompt: String,
    /// Reference images; the key is required on input (no serde default), but an
    /// empty list is not rejected here.
    pub reference_images: Vec<MediaInputAsset>,
    /// Requested duration in whole seconds (per the field name).
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub duration_seconds: Option<u32>,
    /// Requested aspect ratio; string format is provider-defined (TODO confirm).
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub aspect_ratio: Option<String>,
    /// Requested resolution; string format is provider-defined (TODO confirm).
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub resolution: Option<String>,
    /// Opaque provider-specific options; `null` (the default) is omitted when serializing.
    #[serde(default, skip_serializing_if = "serde_json::Value::is_null")]
    pub provider_options: serde_json::Value,
}
255
/// Parameters for [`NativeOperation::ExtendVideo`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExtendVideoRequest {
    /// Model to use.
    pub model: String,
    /// Text prompt.
    pub prompt: String,
    /// Video to extend.
    pub video: MediaInputAsset,
    /// Requested duration in whole seconds (per the field name); whether this is
    /// the added length or the total length is not visible here — TODO confirm.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub duration_seconds: Option<u32>,
    /// Opaque provider-specific options; `null` (the default) is omitted when serializing.
    #[serde(default, skip_serializing_if = "serde_json::Value::is_null")]
    pub provider_options: serde_json::Value,
}
267
/// Parameters for [`NativeOperation::GenerateSpeech`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GenerateSpeechRequest {
    /// Model to use.
    pub model: String,
    /// Text to synthesize.
    pub text: String,
    /// Voice identifier; values are provider-defined (TODO confirm).
    pub voice: String,
    /// Optional language hint; the code format is not validated here.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub language: Option<String>,
    /// Optional audio format as a free-form string (unlike the image requests,
    /// which use [`MediaOutputFormat`]); accepted values are provider-defined.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub output_format: Option<String>,
    /// Opaque provider-specific options; `null` (the default) is omitted when serializing.
    #[serde(default, skip_serializing_if = "serde_json::Value::is_null")]
    pub provider_options: serde_json::Value,
}
281
/// Parameters for [`NativeOperation::TranscribeAudio`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TranscribeAudioRequest {
    /// Model to use.
    pub model: String,
    /// Audio to transcribe.
    pub audio: MediaInputAsset,
    /// Optional language hint; the code format is not validated here.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub language: Option<String>,
    /// Opaque provider-specific options; `null` (the default) is omitted when serializing.
    #[serde(default, skip_serializing_if = "serde_json::Value::is_null")]
    pub provider_options: serde_json::Value,
}
292
/// A native media request: one variant per [`NativeOperation`] except
/// `RealtimeVoiceAgent`.
///
/// Adjacently tagged on the wire, e.g.
/// `{"operation": "generate_image", "request": { ...GenerateImageRequest... }}`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "operation", content = "request", rename_all = "snake_case")]
pub enum NativeMediaRequest {
    GenerateImage(GenerateImageRequest),
    EditImage(EditImageRequest),
    GenerateVideo(GenerateVideoRequest),
    EditVideo(EditVideoRequest),
    ImageToVideo(ImageToVideoRequest),
    ReferenceToVideo(ReferenceToVideoRequest),
    ExtendVideo(ExtendVideoRequest),
    GenerateSpeech(GenerateSpeechRequest),
    TranscribeAudio(TranscribeAudioRequest),
}
307
308impl NativeMediaRequest {
309 pub fn operation(&self) -> NativeOperation {
310 match self {
311 Self::GenerateImage(_) => NativeOperation::GenerateImage,
312 Self::EditImage(_) => NativeOperation::EditImage,
313 Self::GenerateVideo(_) => NativeOperation::GenerateVideo,
314 Self::EditVideo(_) => NativeOperation::EditVideo,
315 Self::ImageToVideo(_) => NativeOperation::ImageToVideo,
316 Self::ReferenceToVideo(_) => NativeOperation::ReferenceToVideo,
317 Self::ExtendVideo(_) => NativeOperation::ExtendVideo,
318 Self::GenerateSpeech(_) => NativeOperation::GenerateSpeech,
319 Self::TranscribeAudio(_) => NativeOperation::TranscribeAudio,
320 }
321 }
322}
323
/// Status of an asynchronous media job, serialized in `snake_case`.
///
/// NOTE(review): presumably `Completed`, `Failed`, `Expired` and `Cancelled`
/// are terminal; nothing in this module encodes that — confirm with providers.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum NativeMediaJobStatus {
    Queued,
    Running,
    Completed,
    Failed,
    Expired,
    Cancelled,
}
335
/// Handle to an asynchronous media job.
///
/// Carried by [`NativeMediaResponse::Job`] and accepted by
/// [`NativeCapabilitiesProvider::poll_media_job`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NativeMediaJob {
    /// Provider name (presumably matches [`ProviderNativeCapabilities::provider`] — confirm).
    pub provider: String,
    /// Operation the job performs.
    pub operation: NativeOperation,
    /// Provider-assigned job identifier.
    pub job_id: String,
    /// Last reported status.
    pub status: NativeMediaJobStatus,
    /// Model handling the job, if known; omitted when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub model: Option<String>,
    /// Opaque provider metadata; omitted when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub metadata: Option<serde_json::Value>,
}
348
/// Result of submitting or polling a native media request.
///
/// Internally tagged by `type`: `{"type": "assets", "assets": [...]}` carries
/// output assets (plus optional opaque `metadata`, omitted when `None`), and
/// `{"type": "job", "job": {...}}` carries a job handle (presumably for
/// work that finishes later — confirm).
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum NativeMediaResponse {
    Assets {
        assets: Vec<MediaOutputAsset>,
        #[serde(default, skip_serializing_if = "Option::is_none")]
        metadata: Option<serde_json::Value>,
    },
    Job {
        job: NativeMediaJob,
    },
}
362
/// Native tools offered for models matching `model_pattern`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelNativeCapabilities {
    /// Model name pattern; the matching syntax (exact, glob, regex?) is applied
    /// by code not shown here — TODO confirm.
    pub model_pattern: String,
    /// Tools available to matching models.
    pub tools: Vec<NativeToolSpec>,
}
369
370impl ModelNativeCapabilities {
371 pub fn operations(&self) -> impl Iterator<Item = NativeOperation> + '_ {
372 self.tools.iter().map(|tool| tool.capability)
373 }
374}
375
/// How a native tool's work is carried out.
///
/// Internally tagged by `mode`: `{"mode": "immediate"}` or
/// `{"mode": "async_job", "poll_supported": true}`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(tag = "mode", rename_all = "snake_case")]
pub enum NativeExecutionMode {
    /// Presumably answered with [`NativeMediaResponse::Assets`] — confirm.
    Immediate,
    /// Presumably answered with [`NativeMediaResponse::Job`]; `poll_supported`
    /// presumably says whether `poll_media_job` works for it — confirm.
    AsyncJob { poll_supported: bool },
}
383
/// Tool-facing description of one native operation available to a model.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NativeToolSpec {
    /// Operation the tool performs.
    pub capability: NativeOperation,
    /// Tool name (presumably `capability.tool_name()` — confirm; not enforced here).
    pub tool_name: String,
    /// Human-readable description.
    pub description: String,
    /// Parameters schema; required on input (no serde default).
    pub parameters_schema: serde_json::Value,
    /// Immediate vs. asynchronous-job execution.
    pub execution: NativeExecutionMode,
}
393
/// Everything a provider advertises natively, returned by
/// [`NativeCapabilitiesProvider::native_capabilities`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProviderNativeCapabilities {
    /// Provider name.
    pub provider: String,
    /// Provider-native model tools; defaults to empty when absent on input.
    #[serde(default)]
    pub model_tools: Vec<ProviderNativeModelToolSpec>,
    /// Native tool sets keyed by model pattern.
    pub models: Vec<ModelNativeCapabilities>,
}
402
403#[async_trait]
405pub trait NativeCapabilitiesProvider: Send + Sync {
406 fn native_capabilities(&self) -> ProviderNativeCapabilities;
407
408 async fn submit_media(
409 &self,
410 request: NativeMediaRequest,
411 ) -> anyhow::Result<NativeMediaResponse>;
412
413 async fn poll_media_job(&self, job: &NativeMediaJob) -> anyhow::Result<NativeMediaResponse> {
414 let _ = job;
415 anyhow::bail!("provider does not support polling native media jobs")
416 }
417}
418
419pub(crate) fn media_input_schema() -> serde_json::Value {
420 serde_json::json!({
421 "oneOf": [
422 {
423 "type": "object",
424 "properties": {
425 "type": {"const": "url"},
426 "url": {"type": "string"}
427 },
428 "required": ["type", "url"]
429 },
430 {
431 "type": "object",
432 "properties": {
433 "type": {"const": "data_uri"},
434 "data_uri": {"type": "string"}
435 },
436 "required": ["type", "data_uri"]
437 },
438 {
439 "type": "object",
440 "properties": {
441 "type": {"const": "provider_file_id"},
442 "file_id": {"type": "string"}
443 },
444 "required": ["type", "file_id"]
445 }
446 ]
447 })
448}
449
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal image-generation request shared by the request tests.
    fn example_image_request() -> NativeMediaRequest {
        NativeMediaRequest::GenerateImage(GenerateImageRequest {
            model: "example-image-model".to_string(),
            prompt: "draw a diagram".to_string(),
            n: None,
            size: None,
            aspect_ratio: None,
            resolution: None,
            output_format: MediaOutputFormat::Url,
            provider_options: serde_json::Value::Null,
        })
    }

    #[test]
    fn media_request_reports_operation() {
        assert_eq!(
            example_image_request().operation(),
            NativeOperation::GenerateImage
        );
    }

    #[test]
    fn media_request_uses_adjacent_operation_tag() {
        // Unset options and a null `provider_options` are skipped; `output_format`
        // has no `skip_serializing_if`, so it is always written.
        assert_eq!(
            serde_json::to_value(example_image_request()).unwrap(),
            serde_json::json!({
                "operation": "generate_image",
                "request": {
                    "model": "example-image-model",
                    "prompt": "draw a diagram",
                    "output_format": "url"
                }
            })
        );

        let parsed: NativeMediaRequest = serde_json::from_value(serde_json::json!({
            "operation": "transcribe_audio",
            "request": {
                "model": "example-audio-model",
                "audio": {"type": "url", "url": "https://example.com/clip.wav"}
            }
        }))
        .expect("valid transcription request");
        assert_eq!(parsed.operation(), NativeOperation::TranscribeAudio);
    }

    #[test]
    fn native_model_tool_id_serializes_as_valid_string() {
        let id: NativeModelToolId =
            serde_json::from_value(serde_json::json!(" provider_search ")).expect("valid id");

        assert_eq!(id.as_str(), "provider_search");
        assert_eq!(
            serde_json::to_value(&id).unwrap(),
            serde_json::json!("provider_search")
        );
        assert!(serde_json::from_value::<NativeModelToolId>(serde_json::json!("")).is_err());
    }

    #[test]
    fn native_model_tool_id_rejects_whitespace_only_values() {
        assert!(NativeModelToolId::new("   ").is_err());
        assert!("\t\n".parse::<NativeModelToolId>().is_err());
        assert!(serde_json::from_value::<NativeModelToolId>(serde_json::json!("  ")).is_err());
        assert_eq!(
            " image_search ".parse::<NativeModelToolId>().unwrap().as_str(),
            "image_search"
        );
    }

    #[test]
    fn realtime_voice_agent_has_no_tool_name() {
        assert_eq!(NativeOperation::RealtimeVoiceAgent.tool_name(), None);
        assert_eq!(
            NativeOperation::GenerateImage.tool_name(),
            Some("generate_image")
        );
    }
}