Skip to main content

nenjo_models/
xai.rs

1//! xAI provider.
2//!
3//! Chat uses xAI's OpenAI-compatible chat completions surface. Provider-native
4//! media operations use xAI-specific endpoints under `https://api.x.ai/v1`.
5
6use async_trait::async_trait;
7use futures_util::StreamExt;
8use reqwest::Client;
9use serde::{Deserialize, Serialize};
10use serde_json::{Value, json};
11use std::collections::{HashMap, HashSet};
12
13use crate::compatible::{AuthStyle, OpenAiCompatibleProvider};
14use crate::native::{
15    EditImageRequest, EditVideoRequest, ExtendVideoRequest, GenerateVideoRequest,
16    ImageToVideoRequest, MediaInputAsset, MediaOutputAsset, MediaOutputFormat,
17    ModelNativeCapabilities, NativeCapabilitiesProvider, NativeExecutionMode, NativeMediaJob,
18    NativeMediaJobStatus, NativeMediaRequest, NativeMediaResponse, NativeModelToolId,
19    NativeOperation, NativeToolSpec, ProviderNativeCapabilities, ProviderNativeModelToolSpec,
20    ReferenceToVideoRequest, media_input_schema,
21};
22use crate::traits::{
23    ChatMessage, ChatRequest, ChatResponse, ModelProvider, ProviderStreamEvent, ProviderToolTrace,
24    TokenUsage, ToolCall,
25};
26
/// Base URL that [`XAiProvider::new`] passes to [`XAiProvider::with_base_url`].
pub const XAI_DEFAULT_BASE_URL: &str = "https://api.x.ai/v1";
28
/// xAI provider handle.
///
/// Bundles the credentials and HTTP client used for the requests built in
/// this file with an OpenAI-compatible provider that is constructed from the
/// same base URL and key.
pub struct XAiProvider {
    /// API key, if one was supplied; `api_key()` returns an error when this is `None`.
    api_key: Option<String>,
    /// Base URL with trailing `/` characters stripped (see `with_base_url`).
    base_url: String,
    /// OpenAI-compatible provider created with the name `"xai"` and bearer auth.
    chat: OpenAiCompatibleProvider,
    /// HTTP client used for direct requests such as `POST /responses`.
    client: Client,
}
35
/// JSON request body for image generation.
///
/// Every `Option` field is omitted from the serialized JSON when it is `None`.
#[derive(Debug, Serialize)]
struct ImageGenerationRequest<'a> {
    /// Model identifier.
    model: &'a str,
    /// Text prompt.
    prompt: &'a str,
    /// Optional `n` value; presumably the number of images requested — confirm against the API.
    #[serde(skip_serializing_if = "Option::is_none")]
    n: Option<u32>,
    /// Optional response format selector.
    #[serde(skip_serializing_if = "Option::is_none")]
    response_format: Option<&'static str>,
    /// Optional aspect ratio string.
    #[serde(skip_serializing_if = "Option::is_none")]
    aspect_ratio: Option<&'a str>,
    /// Optional resolution string.
    #[serde(skip_serializing_if = "Option::is_none")]
    resolution: Option<&'a str>,
}
49
/// JSON request body for image editing.
///
/// Unlike the optional fields, `image` is always serialized. Every `Option`
/// field is omitted from the JSON when it is `None`.
#[derive(Debug, Serialize)]
struct ImageEditRequest<'a> {
    /// Model identifier.
    model: &'a str,
    /// Text prompt describing the edit.
    prompt: &'a str,
    /// Source image reference.
    image: XaiMediaInput,
    /// Optional response format selector.
    #[serde(skip_serializing_if = "Option::is_none")]
    response_format: Option<&'static str>,
    /// Optional aspect ratio string.
    #[serde(skip_serializing_if = "Option::is_none")]
    aspect_ratio: Option<&'a str>,
    /// Optional resolution string.
    #[serde(skip_serializing_if = "Option::is_none")]
    resolution: Option<&'a str>,
}
62
/// JSON request body shared by the video operations.
///
/// Only `model` and `prompt` are always present; every `Option` field is
/// omitted from the JSON when it is `None`.
#[derive(Debug, Serialize)]
struct VideoRequest<'a> {
    /// Model identifier.
    model: &'a str,
    /// Text prompt.
    prompt: &'a str,
    /// Optional duration; serialized under the JSON key `duration`.
    #[serde(rename = "duration", skip_serializing_if = "Option::is_none")]
    duration_seconds: Option<u32>,
    /// Optional aspect ratio string.
    #[serde(skip_serializing_if = "Option::is_none")]
    aspect_ratio: Option<&'a str>,
    /// Optional resolution string.
    #[serde(skip_serializing_if = "Option::is_none")]
    resolution: Option<&'a str>,
    /// Optional single image input.
    #[serde(skip_serializing_if = "Option::is_none")]
    image: Option<XaiMediaInput>,
    /// Optional list of reference image inputs.
    #[serde(skip_serializing_if = "Option::is_none")]
    reference_images: Option<Vec<XaiMediaInput>>,
    /// Optional video input.
    #[serde(skip_serializing_if = "Option::is_none")]
    video: Option<XaiMediaInput>,
}
80
/// Media reference as sent to xAI: either a URL-like value or a file id.
///
/// All fields are optional and omitted from the JSON when `None`;
/// `xai_media_input` decides which ones are populated.
#[derive(Debug, Serialize)]
struct XaiMediaInput {
    /// Optional type tag; serialized under the JSON key `type`.
    #[serde(rename = "type", skip_serializing_if = "Option::is_none")]
    kind: Option<&'static str>,
    /// URL or data URI of the asset.
    #[serde(skip_serializing_if = "Option::is_none")]
    url: Option<String>,
    /// Provider-side file id of the asset.
    #[serde(skip_serializing_if = "Option::is_none")]
    file_id: Option<String>,
}
90
/// Response body of an image request: a required `data` array of results.
#[derive(Debug, Deserialize)]
struct ImageGenerationResponse {
    /// One entry per returned image; deserialization fails if the key is missing.
    data: Vec<ImageGenerationData>,
}
95
/// One image entry in an [`ImageGenerationResponse`].
///
/// Every field defaults to `None` when absent from the JSON.
#[derive(Debug, Deserialize)]
struct ImageGenerationData {
    /// URL of the image, when provided.
    #[serde(default)]
    url: Option<String>,
    /// Contents of the `b64_json` key, when provided (presumably base64 image data).
    #[serde(default)]
    b64_json: Option<String>,
    /// Contents of the `revised_prompt` key, when provided.
    #[serde(default)]
    revised_prompt: Option<String>,
}
105
/// Response body returned when a video job is started.
#[derive(Debug, Deserialize)]
struct VideoStartResponse {
    /// Required request identifier for the started job.
    request_id: String,
}
110
/// Response body returned when polling a video job.
#[derive(Debug, Deserialize)]
struct VideoPollResponse {
    /// Raw job status string (required); see `xai_video_status` for the recognized values.
    status: String,
    /// Video payload; `None` when the key is absent.
    #[serde(default)]
    video: Option<VideoAsset>,
    /// Error payload; `None` when the key is absent.
    #[serde(default)]
    error: Option<XaiError>,
}
119
/// Video payload inside a [`VideoPollResponse`].
#[derive(Debug, Deserialize)]
struct VideoAsset {
    /// Required URL of the video.
    url: String,
    /// Optional duration value; units are not established here (presumably seconds — confirm).
    #[serde(default)]
    duration: Option<f64>,
}
126
/// Error object reported by xAI; both fields default to `None` when absent.
///
/// Also `Serialize`, so it can be re-emitted as JSON.
#[derive(Debug, Deserialize, Serialize)]
struct XaiError {
    /// Error code, when provided.
    #[serde(default)]
    code: Option<String>,
    /// Error message, when provided.
    #[serde(default)]
    message: Option<String>,
}
134
/// JSON request body posted to the `/responses` endpoint.
#[derive(Debug, Serialize)]
struct ResponsesRequest {
    /// Model identifier.
    model: String,
    /// Conversation items, as produced by `responses_input`.
    input: Vec<ResponsesInput>,
    /// Native and function tools, as produced by `native_responses_tools`.
    tools: Vec<ResponsesTool>,
    /// Sampling temperature.
    temperature: f64,
    /// Value of the `stream` flag sent in the request.
    stream: bool,
}
143
/// One item of the `input` array in a [`ResponsesRequest`].
///
/// Serialized untagged: each variant becomes a plain JSON object containing
/// only its own fields, with no variant name wrapper.
#[derive(Debug, Serialize)]
#[serde(untagged)]
enum ResponsesInput {
    /// Plain chat message with a role and text content.
    Message {
        role: String,
        content: String,
    },
    /// A function call previously issued by the assistant.
    FunctionCall {
        /// Serialized under the JSON key `type`; `responses_input` sets it to `"function_call"`.
        #[serde(rename = "type")]
        kind: &'static str,
        call_id: String,
        name: String,
        /// Call arguments as a string.
        arguments: String,
    },
    /// The result of a function call, linked back by `call_id`.
    FunctionCallOutput {
        /// Serialized under the JSON key `type`; `responses_input` sets it to `"function_call_output"`.
        #[serde(rename = "type")]
        kind: &'static str,
        call_id: String,
        output: String,
    },
}
165
/// Tool entry in a [`ResponsesRequest`].
///
/// `native_responses_tools` emits native tools with only `kind` set and
/// function tools with `kind = "function"` plus the three optional fields.
/// Optional fields are omitted from the JSON when `None`.
#[derive(Debug, Serialize, PartialEq)]
struct ResponsesTool {
    /// Tool type; serialized under the JSON key `type`.
    #[serde(rename = "type")]
    kind: String,
    /// Function name (function tools only).
    #[serde(skip_serializing_if = "Option::is_none")]
    name: Option<String>,
    /// Function description (function tools only).
    #[serde(skip_serializing_if = "Option::is_none")]
    description: Option<String>,
    /// Function parameter schema (function tools only).
    #[serde(skip_serializing_if = "Option::is_none")]
    parameters: Option<Value>,
}
177
/// Decoded body of a Responses API reply (or of a streamed final response).
///
/// All fields fall back to their defaults when absent from the JSON.
#[derive(Debug, Clone, Deserialize)]
struct ResponsesResponse {
    /// Output items; empty when the key is absent.
    #[serde(default)]
    output: Vec<ResponsesOutput>,
    /// Top-level `output_text`, when present; `responses_text` prefers it over item content.
    #[serde(default)]
    output_text: Option<String>,
    /// Token usage, when reported.
    #[serde(default)]
    usage: Option<ResponsesUsage>,
}
187
/// One item of a response's `output` array.
///
/// Every named field is optional or defaulted, and any other JSON keys are
/// collected into `extra`, so any JSON object deserializes successfully.
#[derive(Debug, Clone, Deserialize)]
struct ResponsesOutput {
    /// Item id, when present.
    #[serde(default)]
    id: Option<String>,
    /// Call id, when present; preferred over `id` when deriving tool-call ids.
    #[serde(default)]
    call_id: Option<String>,
    /// Item type, read from the JSON key `type` (e.g. `"function_call"`).
    #[serde(rename = "type", default)]
    kind: Option<String>,
    /// Tool or function name, when present.
    #[serde(default)]
    name: Option<String>,
    /// Call arguments; may be a JSON string or any other JSON value.
    #[serde(default)]
    arguments: Option<Value>,
    /// Content parts; empty when the key is absent.
    #[serde(default)]
    content: Vec<ResponsesContent>,
    /// Item status string, when present.
    #[serde(default)]
    status: Option<String>,
    /// All remaining keys of the JSON object that are not named above.
    #[serde(flatten)]
    extra: serde_json::Map<String, Value>,
}
207
/// One content part of a [`ResponsesOutput`] item.
///
/// Also `Serialize` because tool traces embed the content back into JSON.
#[derive(Debug, Clone, Deserialize, Serialize)]
struct ResponsesContent {
    /// Part type under the JSON key `type` (e.g. `"output_text"`).
    #[serde(rename = "type", default)]
    kind: Option<String>,
    /// Text of the part, when present.
    #[serde(default)]
    text: Option<String>,
    /// Raw annotation values; empty when absent. Tool traces treat these as citations.
    #[serde(default)]
    annotations: Vec<Value>,
}
217
/// Token counts reported with a response; missing counts default to 0.
#[derive(Debug, Clone, Deserialize)]
struct ResponsesUsage {
    /// Input token count; the key `prompt_tokens` is accepted as an alias.
    #[serde(default, alias = "prompt_tokens")]
    input_tokens: u64,
    /// Output token count; the key `completion_tokens` is accepted as an alias.
    #[serde(default, alias = "completion_tokens")]
    output_tokens: u64,
}
225
226fn xai_media_input(asset: MediaInputAsset, image_edit_input: bool) -> XaiMediaInput {
227    let kind = match &asset {
228        MediaInputAsset::ProviderFileId { .. } => None,
229        MediaInputAsset::Url { .. } | MediaInputAsset::DataUri { .. } => {
230            image_edit_input.then_some("image_url")
231        }
232    };
233    match asset {
234        MediaInputAsset::Url { url } => XaiMediaInput {
235            kind,
236            url: Some(url),
237            file_id: None,
238        },
239        MediaInputAsset::DataUri { data_uri } => XaiMediaInput {
240            kind,
241            url: Some(data_uri),
242            file_id: None,
243        },
244        MediaInputAsset::ProviderFileId { file_id } => XaiMediaInput {
245            kind,
246            url: None,
247            file_id: Some(file_id),
248        },
249    }
250}
251
252fn xai_image_tool_spec(operation: NativeOperation) -> NativeToolSpec {
253    let mut properties = json!({
254        "prompt": {"type": "string"},
255        "n": {"type": "integer", "minimum": 1},
256        "aspect_ratio": {
257            "type": "string",
258            "enum": [
259                "1:1", "16:9", "9:16", "4:3", "3:4", "3:2", "2:3",
260                "2:1", "1:2", "19.5:9", "9:19.5", "20:9", "9:20", "auto"
261            ]
262        },
263        "resolution": {"type": "string", "enum": ["1k", "2k"]},
264        "output_format": {"type": "string", "enum": ["url", "base64"]},
265        "provider_options": {
266            "type": "object",
267            "properties": {},
268            "additionalProperties": false
269        }
270    });
271    let required = match operation {
272        NativeOperation::GenerateImage => vec!["prompt"],
273        NativeOperation::EditImage => {
274            properties["image"] = media_input_schema();
275            vec!["prompt", "image"]
276        }
277        other => panic!("unsupported xAI image operation {other:?}"),
278    };
279
280    NativeToolSpec {
281        capability: operation,
282        tool_name: operation.tool_name().unwrap().to_string(),
283        description: match operation {
284            NativeOperation::GenerateImage => {
285                "Generate an image with the configured xAI image model."
286            }
287            NativeOperation::EditImage => "Edit an image with the configured xAI image model.",
288            _ => unreachable!(),
289        }
290        .to_string(),
291        parameters_schema: json!({
292            "type": "object",
293            "properties": properties,
294            "required": required
295        }),
296        execution: NativeExecutionMode::Immediate,
297    }
298}
299
300fn xai_video_provider_options() -> Value {
301    json!({
302        "type": "object",
303        "properties": {
304            "poll_timeout_ms": {
305                "type": "integer",
306                "minimum": 1
307            }
308        },
309        "additionalProperties": false
310    })
311}
312
313fn xai_video_base_properties() -> Value {
314    json!({
315        "prompt": {"type": "string"},
316        "duration_seconds": {"type": "integer", "minimum": 1},
317        "aspect_ratio": {"type": "string", "enum": ["16:9", "9:16", "1:1"]},
318        "resolution": {"type": "string", "enum": ["480p", "720p"]},
319        "provider_options": xai_video_provider_options()
320    })
321}
322
323fn xai_video_tool_spec(operation: NativeOperation) -> NativeToolSpec {
324    let mut properties = xai_video_base_properties();
325    let required = match operation {
326        NativeOperation::GenerateVideo => vec!["prompt"],
327        NativeOperation::ImageToVideo => {
328            properties["image"] = media_input_schema();
329            vec!["prompt", "image"]
330        }
331        NativeOperation::ReferenceToVideo => {
332            properties["reference_images"] = json!({
333                "type": "array",
334                "items": media_input_schema(),
335                "minItems": 1,
336                "maxItems": 7
337            });
338            properties["duration_seconds"]["maximum"] = json!(10);
339            vec!["prompt", "reference_images"]
340        }
341        NativeOperation::EditVideo => {
342            properties = json!({
343                "prompt": {"type": "string"},
344                "video": media_input_schema(),
345                "provider_options": xai_video_provider_options()
346            });
347            vec!["prompt", "video"]
348        }
349        NativeOperation::ExtendVideo => {
350            properties = json!({
351                "prompt": {"type": "string"},
352                "video": media_input_schema(),
353                "duration_seconds": {
354                    "type": "integer",
355                    "minimum": 2,
356                    "maximum": 10
357                },
358                "provider_options": xai_video_provider_options()
359            });
360            vec!["prompt", "video"]
361        }
362        other => panic!("unsupported xAI video operation {other:?}"),
363    };
364
365    NativeToolSpec {
366        capability: operation,
367        tool_name: operation.tool_name().unwrap().to_string(),
368        description: match operation {
369            NativeOperation::GenerateVideo => "Start an asynchronous xAI video generation job. A successful call means the render was queued, not finished; use wait_operations with kind=media for the returned operation_id until it completes. Do not call generate_video again for the same prompt unless the user explicitly asks for another independent video.",
370            NativeOperation::EditVideo => "Start an asynchronous xAI video editing job. A successful call means the render was queued, not finished; use wait_operations with kind=media for the returned operation_id until it completes. Do not call edit_video again for the same request unless the user explicitly asks for another independent edit.",
371            NativeOperation::ImageToVideo => "Start an asynchronous xAI image-to-video job. A successful call means the render was queued, not finished; use wait_operations with kind=media for the returned operation_id until it completes. Do not call image_to_video again for the same request unless the user explicitly asks for another independent video.",
372            NativeOperation::ReferenceToVideo => "Start an asynchronous xAI reference-to-video job. A successful call means the render was queued, not finished; use wait_operations with kind=media for the returned operation_id until it completes. Do not call reference_to_video again for the same request unless the user explicitly asks for another independent video.",
373            NativeOperation::ExtendVideo => "Start an asynchronous xAI video extension job. A successful call means the render was queued, not finished; use wait_operations with kind=media for the returned operation_id until it completes. Do not call extend_video again for the same request unless the user explicitly asks for another independent extension.",
374            _ => unreachable!(),
375        }
376        .to_string(),
377        parameters_schema: json!({
378            "type": "object",
379            "properties": properties,
380            "required": required
381        }),
382        execution: NativeExecutionMode::AsyncJob {
383            poll_supported: true,
384        },
385    }
386}
387
388fn xai_video_status(status: &str) -> anyhow::Result<NativeMediaJobStatus> {
389    match status {
390        "pending" => Ok(NativeMediaJobStatus::Running),
391        "done" => Ok(NativeMediaJobStatus::Completed),
392        "expired" => Ok(NativeMediaJobStatus::Expired),
393        "failed" => Ok(NativeMediaJobStatus::Failed),
394        other => anyhow::bail!("unknown xAI video job status '{other}'"),
395    }
396}
397
/// Returns the trimmed text as an owned `String`, or `None` when the input is
/// absent or contains only whitespace.
fn first_nonempty(text: Option<&str>) -> Option<String> {
    text.map(str::trim)
        .filter(|trimmed| !trimmed.is_empty())
        .map(ToString::to_string)
}
408
409fn xai_native_model_tool_specs() -> Vec<ProviderNativeModelToolSpec> {
410    vec![
411        ProviderNativeModelToolSpec {
412            id: NativeModelToolId::from("web_search"),
413            provider_type: "web_search".to_string(),
414            name: "web_search".to_string(),
415            description: "Provider-native xAI web search for current web results and citations."
416                .to_string(),
417            parameters_schema: Some(json!({
418                "type": "object",
419                "properties": {},
420                "additionalProperties": false
421            })),
422            config_schema: None,
423        },
424        ProviderNativeModelToolSpec {
425            id: NativeModelToolId::from("x_search"),
426            provider_type: "x_search".to_string(),
427            name: "x_search".to_string(),
428            description:
429                "Provider-native xAI X search for posts, discussions, and current activity on X."
430                    .to_string(),
431            parameters_schema: Some(json!({
432                "type": "object",
433                "properties": {},
434                "additionalProperties": false
435            })),
436            config_schema: None,
437        },
438    ]
439}
440
441fn xai_native_model_tool_spec(tool_id: &NativeModelToolId) -> Option<ProviderNativeModelToolSpec> {
442    xai_native_model_tool_specs()
443        .into_iter()
444        .find(|spec| spec.id == *tool_id)
445}
446
447fn native_responses_tools(
448    native_tools: &[NativeModelToolId],
449    local_tools: Option<&[crate::ToolSpec]>,
450) -> anyhow::Result<Vec<ResponsesTool>> {
451    let mut tools = Vec::with_capacity(native_tools.len() + local_tools.map_or(0, <[_]>::len));
452    for tool_id in native_tools {
453        let tool = xai_native_model_tool_spec(tool_id)
454            .ok_or_else(|| anyhow::anyhow!("xAI does not support native model tool '{tool_id}'"))?;
455        tools.push(ResponsesTool {
456            kind: tool.provider_type,
457            name: None,
458            description: None,
459            parameters: None,
460        });
461    }
462
463    if let Some(local_tools) = local_tools {
464        tools.extend(local_tools.iter().map(|tool| ResponsesTool {
465            kind: "function".to_string(),
466            name: Some(crate::sanitize_tool_name(&tool.name)),
467            description: Some(tool.description.clone()),
468            parameters: Some(tool.parameters.clone()),
469        }));
470    }
471
472    Ok(tools)
473}
474
475fn responses_input(messages: &[ChatMessage]) -> Vec<ResponsesInput> {
476    let mut input = Vec::with_capacity(messages.len());
477
478    for message in messages {
479        if message.role == "assistant"
480            && let Ok(value) = serde_json::from_str::<Value>(&message.content)
481            && let Some(tool_calls_value) = value.get("tool_calls")
482            && let Ok(tool_calls) =
483                serde_json::from_value::<Vec<ToolCall>>(tool_calls_value.clone())
484        {
485            if let Some(content) = value
486                .get("content")
487                .and_then(Value::as_str)
488                .and_then(|text| first_nonempty(Some(text)))
489            {
490                input.push(ResponsesInput::Message {
491                    role: "assistant".to_string(),
492                    content,
493                });
494            }
495
496            input.extend(
497                tool_calls
498                    .into_iter()
499                    .map(|call| ResponsesInput::FunctionCall {
500                        kind: "function_call",
501                        call_id: call.id,
502                        name: call.name,
503                        arguments: call.arguments,
504                    }),
505            );
506            continue;
507        }
508
509        if message.role == "tool"
510            && let Ok(value) = serde_json::from_str::<Value>(&message.content)
511            && let Some(call_id) = value.get("tool_call_id").and_then(Value::as_str)
512        {
513            let output = value
514                .get("content")
515                .and_then(Value::as_str)
516                .unwrap_or_default()
517                .to_string();
518            input.push(ResponsesInput::FunctionCallOutput {
519                kind: "function_call_output",
520                call_id: call_id.to_string(),
521                output,
522            });
523            continue;
524        }
525
526        input.push(ResponsesInput::Message {
527            role: message.role.clone(),
528            content: message.content.clone(),
529        });
530    }
531
532    input
533}
534
535fn responses_text(response: &ResponsesResponse) -> Option<String> {
536    if let Some(text) = first_nonempty(response.output_text.as_deref()) {
537        return Some(text);
538    }
539
540    for item in &response.output {
541        for content in &item.content {
542            if content.kind.as_deref() == Some("output_text")
543                && let Some(text) = first_nonempty(content.text.as_deref())
544            {
545                return Some(text);
546            }
547        }
548    }
549
550    for item in &response.output {
551        for content in &item.content {
552            if let Some(text) = first_nonempty(content.text.as_deref()) {
553                return Some(text);
554            }
555        }
556    }
557
558    None
559}
560
561fn responses_tool_calls(response: &ResponsesResponse) -> Vec<ToolCall> {
562    response
563        .output
564        .iter()
565        .filter(|item| item.kind.as_deref() == Some("function_call"))
566        .filter_map(|item| {
567            let name = item.name.clone()?;
568            let arguments = match item.arguments.as_ref() {
569                Some(Value::String(value)) => value.clone(),
570                Some(value) => value.to_string(),
571                None => "{}".to_string(),
572            };
573            Some(ToolCall {
574                id: item
575                    .call_id
576                    .clone()
577                    .or_else(|| item.id.clone())
578                    .unwrap_or_else(|| uuid::Uuid::new_v4().to_string()),
579                name,
580                arguments,
581            })
582        })
583        .collect()
584}
585
/// Maps a Responses output item type (e.g. `web_search_call`) to the name of
/// the native tool it belongs to, or `None` for any other item type.
///
/// The match is exact; no prefix or substring matching is performed.
fn xai_native_tool_name(output_kind: &str) -> Option<&'static str> {
    const TOOL_BY_ITEM_KIND: [(&str, &str); 5] = [
        ("web_search_call", "web_search"),
        ("x_search_call", "x_search"),
        ("code_interpreter_call", "code_interpreter"),
        ("file_search_call", "file_search"),
        ("mcp_call", "mcp"),
    ];

    TOOL_BY_ITEM_KIND
        .iter()
        .find(|(item_kind, _)| *item_kind == output_kind)
        .map(|(_, tool)| *tool)
}
596
597fn provider_tool_trace_from_responses_output(item: &ResponsesOutput) -> Option<ProviderToolTrace> {
598    let kind = item.kind.as_deref()?;
599    let name = xai_native_tool_name(kind)?;
600    let id = item
601        .call_id
602        .clone()
603        .or_else(|| item.id.clone())
604        .unwrap_or_else(|| uuid::Uuid::new_v4().to_string());
605
606    let mut input = serde_json::Map::new();
607    input.insert(
608        "response_item_type".to_string(),
609        Value::String(kind.to_string()),
610    );
611    if let Some(status) = &item.status {
612        input.insert("status".to_string(), Value::String(status.clone()));
613    }
614    if let Some(arguments) = &item.arguments {
615        input.insert("arguments".to_string(), arguments.clone());
616    }
617    if let Some(name) = &item.name {
618        input.insert("name".to_string(), Value::String(name.clone()));
619    }
620    for key in [
621        "action",
622        "query",
623        "queries",
624        "server_label",
625        "server_url",
626        "vector_store_ids",
627    ] {
628        if let Some(value) = item.extra.get(key) {
629            input.insert(key.to_string(), value.clone());
630        }
631    }
632
633    let mut output = item.extra.clone();
634    output.remove("action");
635    output.remove("query");
636    output.remove("queries");
637    output.remove("server_label");
638    output.remove("server_url");
639    output.remove("vector_store_ids");
640    if !item.content.is_empty() {
641        output.insert(
642            "content".to_string(),
643            serde_json::to_value(&item.content).unwrap_or(Value::Null),
644        );
645    }
646
647    let mut citations = Vec::new();
648    for key in ["citations", "sources", "results"] {
649        if let Some(value) = item.extra.get(key) {
650            citations.push(value.clone());
651        }
652    }
653    for content in &item.content {
654        citations.extend(content.annotations.iter().cloned());
655    }
656
657    Some(ProviderToolTrace {
658        id,
659        name: name.to_string(),
660        provider: "xai".to_string(),
661        input: Value::Object(input),
662        output: (!output.is_empty()).then_some(Value::Object(output)),
663        citations,
664    })
665}
666
667fn responses_provider_tool_traces(response: &ResponsesResponse) -> Vec<ProviderToolTrace> {
668    response
669        .output
670        .iter()
671        .filter_map(provider_tool_trace_from_responses_output)
672        .collect()
673}
674
/// Accumulator for a streamed Responses API call.
///
/// Updated event by event in `handle_responses_stream_value` and converted to
/// a [`ResponsesResponse`] by `into_response`.
#[derive(Default)]
struct ResponsesStreamState {
    /// Concatenation of all text deltas seen so far.
    text: String,
    /// Native-tool output items, keyed by the id of the trace derived from them.
    output: HashMap<String, ResponsesOutput>,
    /// The most recent full response carried by a `*.completed` event, if any.
    final_response: Option<ResponsesResponse>,
    /// Trace ids for which a "started" event has already been emitted.
    started_provider_tools: HashSet<String>,
    /// Trace ids for which a "completed" event has already been emitted.
    completed_provider_tools: HashSet<String>,
}
683
684impl ResponsesStreamState {
685    fn into_response(self) -> ResponsesResponse {
686        self.final_response.unwrap_or_else(|| ResponsesResponse {
687            output: self.output.into_values().collect(),
688            output_text: (!self.text.is_empty()).then_some(self.text),
689            usage: None,
690        })
691    }
692}
693
694fn stream_event_type(value: &Value) -> Option<&str> {
695    value.get("type").and_then(Value::as_str)
696}
697
698fn stream_text_delta(value: &Value) -> Option<&str> {
699    let kind = stream_event_type(value).unwrap_or_default();
700    if kind.contains("output_text.delta") || kind.contains("text.delta") {
701        return value.get("delta").and_then(Value::as_str);
702    }
703    None
704}
705
706fn stream_response(value: &Value) -> Option<ResponsesResponse> {
707    let kind = stream_event_type(value).unwrap_or_default();
708    if !(kind.ends_with(".completed") || kind == "response.completed") {
709        return None;
710    }
711    value
712        .get("response")
713        .cloned()
714        .and_then(|response| serde_json::from_value(response).ok())
715}
716
717fn stream_output_item(value: &Value) -> Option<ResponsesOutput> {
718    for key in ["item", "output_item", "response_item"] {
719        if let Some(item) = value.get(key)
720            && let Ok(output) = serde_json::from_value::<ResponsesOutput>(item.clone())
721        {
722            return Some(output);
723        }
724    }
725    serde_json::from_value::<ResponsesOutput>(value.clone()).ok()
726}
727
728fn stream_tool_phase(value: &Value, output: &ResponsesOutput) -> Option<&'static str> {
729    let kind = stream_event_type(value).unwrap_or_default();
730    if kind.contains(".added") || kind.contains(".in_progress") || kind.contains(".started") {
731        return Some("started");
732    }
733    if kind.contains(".done") || kind.contains(".completed") {
734        return Some("completed");
735    }
736    match output.status.as_deref() {
737        Some("in_progress" | "running" | "searching" | "started") => Some("started"),
738        Some("completed" | "done") => Some("completed"),
739        _ => None,
740    }
741}
742
/// Finds the native tool item type named inside a stream event type.
///
/// Returns the first known item type (in the fixed order below) that occurs
/// as a substring of `kind`, or `None` when there is no match.
fn native_kind_from_stream_type(kind: &str) -> Option<&'static str> {
    const NATIVE_ITEM_KINDS: [&str; 5] = [
        "web_search_call",
        "x_search_call",
        "code_interpreter_call",
        "file_search_call",
        "mcp_call",
    ];

    for candidate in NATIVE_ITEM_KINDS {
        if kind.contains(candidate) {
            return Some(candidate);
        }
    }
    None
}
754
755fn stream_raw_provider_tool_trace(value: &Value) -> Option<ProviderToolTrace> {
756    let kind = stream_event_type(value)?;
757    let response_item_type = native_kind_from_stream_type(kind)?;
758    let name = xai_native_tool_name(response_item_type)?;
759    let id = value
760        .get("call_id")
761        .or_else(|| value.get("item_id"))
762        .or_else(|| value.get("id"))
763        .and_then(Value::as_str)
764        .map(ToString::to_string)
765        .unwrap_or_else(|| uuid::Uuid::new_v4().to_string());
766
767    let mut input = serde_json::Map::new();
768    input.insert(
769        "response_item_type".to_string(),
770        Value::String(response_item_type.to_string()),
771    );
772    input.insert(
773        "stream_event_type".to_string(),
774        Value::String(kind.to_string()),
775    );
776    if let Some(status) = value.get("status").and_then(Value::as_str) {
777        input.insert("status".to_string(), Value::String(status.to_string()));
778    }
779    for key in ["action", "query", "queries", "server_label", "server_url"] {
780        if let Some(field) = value.get(key) {
781            input.insert(key.to_string(), field.clone());
782        }
783    }
784
785    Some(ProviderToolTrace {
786        id,
787        name: name.to_string(),
788        provider: "xai".to_string(),
789        input: Value::Object(input),
790        output: None,
791        citations: Vec::new(),
792    })
793}
794
795fn stream_raw_provider_tool_phase(value: &Value) -> Option<&'static str> {
796    let kind = stream_event_type(value)?;
797    native_kind_from_stream_type(kind)?;
798    if kind.contains(".done") || kind.contains(".completed") {
799        Some("completed")
800    } else {
801        Some("started")
802    }
803}
804
805fn handle_responses_stream_value(
806    value: Value,
807    state: &mut ResponsesStreamState,
808    events: &tokio::sync::mpsc::UnboundedSender<ProviderStreamEvent>,
809) {
810    if let Some(delta) = stream_text_delta(&value)
811        && !delta.is_empty()
812    {
813        state.text.push_str(delta);
814        let _ = events.send(ProviderStreamEvent::TextDelta(delta.to_string()));
815    }
816
817    if let Some(response) = stream_response(&value) {
818        state.final_response = Some(response);
819    }
820
821    if let Some(output) = stream_output_item(&value)
822        && let Some(trace) = provider_tool_trace_from_responses_output(&output)
823    {
824        let phase = stream_tool_phase(&value, &output);
825        state.output.insert(trace.id.clone(), output);
826        match phase {
827            Some("started") => {
828                if state.started_provider_tools.insert(trace.id.clone()) {
829                    let _ = events.send(ProviderStreamEvent::ProviderToolStarted(trace));
830                }
831            }
832            Some("completed") => {
833                if state.started_provider_tools.insert(trace.id.clone()) {
834                    let _ = events.send(ProviderStreamEvent::ProviderToolStarted(trace.clone()));
835                }
836                if state.completed_provider_tools.insert(trace.id.clone()) {
837                    let _ = events.send(ProviderStreamEvent::ProviderToolCompleted(trace));
838                }
839            }
840            _ => {}
841        }
842        return;
843    }
844
845    if let Some(trace) = stream_raw_provider_tool_trace(&value) {
846        match stream_raw_provider_tool_phase(&value) {
847            Some("completed") => {
848                if state.started_provider_tools.insert(trace.id.clone()) {
849                    let _ = events.send(ProviderStreamEvent::ProviderToolStarted(trace.clone()));
850                }
851                if state.completed_provider_tools.insert(trace.id.clone()) {
852                    let _ = events.send(ProviderStreamEvent::ProviderToolCompleted(trace));
853                }
854            }
855            Some("started") => {
856                if state.started_provider_tools.insert(trace.id.clone()) {
857                    let _ = events.send(ProviderStreamEvent::ProviderToolStarted(trace));
858                }
859            }
860            _ => {}
861        }
862    }
863}
864
865impl XAiProvider {
    /// Creates a provider that targets [`XAI_DEFAULT_BASE_URL`].
    ///
    /// `api_key` may be `None`; see [`XAiProvider::with_base_url`] for how the
    /// provider is assembled.
    pub fn new(api_key: Option<&str>) -> Self {
        Self::with_base_url(api_key, XAI_DEFAULT_BASE_URL)
    }
869
870    pub fn with_base_url(api_key: Option<&str>, base_url: &str) -> Self {
871        let normalized_base_url = base_url.trim_end_matches('/').to_string();
872        Self {
873            api_key: api_key.map(ToString::to_string),
874            base_url: normalized_base_url.clone(),
875            chat: OpenAiCompatibleProvider::new(
876                "xai",
877                &normalized_base_url,
878                api_key,
879                AuthStyle::Bearer,
880            ),
881            client: Client::builder()
882                .timeout(std::time::Duration::from_secs(120))
883                .connect_timeout(std::time::Duration::from_secs(10))
884                .build()
885                .unwrap_or_else(|_| Client::new()),
886        }
887    }
888
889    fn endpoint(&self, path: &str) -> String {
890        format!("{}/{}", self.base_url, path.trim_start_matches('/'))
891    }
892
893    fn api_key(&self) -> anyhow::Result<&str> {
894        self.api_key.as_deref().ok_or_else(|| {
895            anyhow::anyhow!("xAI API key not set. Set XAI_API_KEY or edit config.toml.")
896        })
897    }
898
899    async fn chat_with_native_model_tools(
900        &self,
901        request: ChatRequest<'_>,
902        model: &str,
903        temperature: f64,
904        native_tools: &[NativeModelToolId],
905    ) -> anyhow::Result<ChatResponse> {
906        let api_key = self.api_key()?;
907        let body = ResponsesRequest {
908            model: model.to_string(),
909            input: responses_input(request.messages),
910            tools: native_responses_tools(native_tools, request.tools)?,
911            temperature,
912            stream: false,
913        };
914
915        let response = self
916            .client
917            .post(self.endpoint("/responses"))
918            .header("Authorization", format!("Bearer {api_key}"))
919            .json(&body)
920            .send()
921            .await?;
922
923        if !response.status().is_success() {
924            return Err(crate::api_error("xAI", response).await);
925        }
926
927        let body_text = response.text().await?;
928        let response: ResponsesResponse = serde_json::from_str(&body_text).map_err(|error| {
929            anyhow::anyhow!(
930                "xAI Responses API decode error: {error}\nBody: {}",
931                &body_text[..body_text.len().min(500)]
932            )
933        })?;
934
935        let usage = response
936            .usage
937            .as_ref()
938            .map(|usage| TokenUsage {
939                input_tokens: usage.input_tokens,
940                output_tokens: usage.output_tokens,
941            })
942            .unwrap_or_default();
943        let text = responses_text(&response);
944        let tool_calls = responses_tool_calls(&response);
945        let provider_tool_calls = responses_provider_tool_traces(&response);
946
947        Ok(ChatResponse {
948            text,
949            tool_calls,
950            provider_tool_calls,
951            usage,
952        })
953    }
954
955    async fn chat_with_native_model_tools_streaming(
956        &self,
957        request: ChatRequest<'_>,
958        model: &str,
959        temperature: f64,
960        native_tools: &[NativeModelToolId],
961        events: tokio::sync::mpsc::UnboundedSender<ProviderStreamEvent>,
962    ) -> anyhow::Result<ChatResponse> {
963        let api_key = self.api_key()?;
964        let body = ResponsesRequest {
965            model: model.to_string(),
966            input: responses_input(request.messages),
967            tools: native_responses_tools(native_tools, request.tools)?,
968            temperature,
969            stream: true,
970        };
971        let response = self
972            .client
973            .post(self.endpoint("/responses"))
974            .header("Authorization", format!("Bearer {api_key}"))
975            .json(&body)
976            .send()
977            .await?;
978
979        if !response.status().is_success() {
980            return Err(crate::api_error("xAI", response).await);
981        }
982
983        let mut state = ResponsesStreamState::default();
984        let mut stream = response.bytes_stream();
985        let mut buffer = String::new();
986
987        while let Some(chunk) = stream.next().await {
988            let chunk = chunk?;
989            buffer.push_str(&String::from_utf8_lossy(&chunk));
990            if buffer.contains("\r\n") {
991                buffer = buffer.replace("\r\n", "\n");
992            }
993
994            while let Some(split_at) = buffer.find("\n\n") {
995                let frame = buffer[..split_at].to_string();
996                buffer = buffer[split_at + 2..].to_string();
997
998                for line in frame.lines() {
999                    let Some(data) = line.strip_prefix("data:") else {
1000                        continue;
1001                    };
1002                    let data = data.trim();
1003                    if data.is_empty() || data == "[DONE]" {
1004                        continue;
1005                    }
1006                    if let Ok(value) = serde_json::from_str::<Value>(data) {
1007                        handle_responses_stream_value(value, &mut state, &events);
1008                    }
1009                }
1010            }
1011        }
1012
1013        if !buffer.trim().is_empty() {
1014            for line in buffer.lines() {
1015                let Some(data) = line.strip_prefix("data:") else {
1016                    continue;
1017                };
1018                let data = data.trim();
1019                if data.is_empty() || data == "[DONE]" {
1020                    continue;
1021                }
1022                if let Ok(value) = serde_json::from_str::<Value>(data) {
1023                    handle_responses_stream_value(value, &mut state, &events);
1024                }
1025            }
1026        }
1027
1028        let response = state.into_response();
1029        let usage = response
1030            .usage
1031            .as_ref()
1032            .map(|usage| TokenUsage {
1033                input_tokens: usage.input_tokens,
1034                output_tokens: usage.output_tokens,
1035            })
1036            .unwrap_or_default();
1037        let text = responses_text(&response);
1038        let tool_calls = responses_tool_calls(&response);
1039        let provider_tool_calls = responses_provider_tool_traces(&response);
1040
1041        Ok(ChatResponse {
1042            text,
1043            tool_calls,
1044            provider_tool_calls,
1045            usage,
1046        })
1047    }
1048
1049    async fn generate_image(
1050        &self,
1051        request: crate::native::GenerateImageRequest,
1052    ) -> anyhow::Result<NativeMediaResponse> {
1053        let api_key = self.api_key()?;
1054
1055        let response_format = match request.output_format {
1056            MediaOutputFormat::Url => None,
1057            MediaOutputFormat::Base64 => Some("b64_json"),
1058        };
1059        let body = ImageGenerationRequest {
1060            model: &request.model,
1061            prompt: &request.prompt,
1062            n: request.n,
1063            response_format,
1064            aspect_ratio: request.aspect_ratio.as_deref(),
1065            resolution: request.resolution.as_deref(),
1066        };
1067
1068        let response = self
1069            .client
1070            .post(self.endpoint("/images/generations"))
1071            .header("Authorization", format!("Bearer {api_key}"))
1072            .json(&body)
1073            .send()
1074            .await?;
1075
1076        if !response.status().is_success() {
1077            return Err(crate::api_error("xAI", response).await);
1078        }
1079
1080        let images: ImageGenerationResponse = response.json().await?;
1081        let mut assets = Vec::new();
1082        let mut revised_prompts = Vec::new();
1083
1084        for image in images.data {
1085            if let Some(prompt) = image.revised_prompt {
1086                revised_prompts.push(prompt);
1087            }
1088            if let Some(url) = image.url {
1089                assets.push(MediaOutputAsset::Url {
1090                    url,
1091                    mime_type: Some("image/jpeg".to_string()),
1092                });
1093            } else if let Some(data) = image.b64_json {
1094                assets.push(MediaOutputAsset::Base64 {
1095                    data,
1096                    mime_type: Some("image/jpeg".to_string()),
1097                });
1098            }
1099        }
1100
1101        if assets.is_empty() {
1102            anyhow::bail!("xAI image generation returned no assets");
1103        }
1104
1105        let metadata = if revised_prompts.is_empty() {
1106            None
1107        } else {
1108            Some(serde_json::json!({ "revised_prompts": revised_prompts }))
1109        };
1110
1111        Ok(NativeMediaResponse::Assets { assets, metadata })
1112    }
1113
1114    async fn edit_image(&self, request: EditImageRequest) -> anyhow::Result<NativeMediaResponse> {
1115        let api_key = self.api_key()?;
1116        let response_format = match request.output_format {
1117            MediaOutputFormat::Url => None,
1118            MediaOutputFormat::Base64 => Some("b64_json"),
1119        };
1120        let body = ImageEditRequest {
1121            model: &request.model,
1122            prompt: &request.prompt,
1123            image: xai_media_input(request.image, true),
1124            response_format,
1125            aspect_ratio: request.aspect_ratio.as_deref(),
1126            resolution: request.resolution.as_deref(),
1127        };
1128
1129        let response = self
1130            .client
1131            .post(self.endpoint("/images/edits"))
1132            .header("Authorization", format!("Bearer {api_key}"))
1133            .json(&body)
1134            .send()
1135            .await?;
1136
1137        if !response.status().is_success() {
1138            return Err(crate::api_error("xAI", response).await);
1139        }
1140
1141        self.parse_image_response(response).await
1142    }
1143
1144    async fn parse_image_response(
1145        &self,
1146        response: reqwest::Response,
1147    ) -> anyhow::Result<NativeMediaResponse> {
1148        let images: ImageGenerationResponse = response.json().await?;
1149        let mut assets = Vec::new();
1150        let mut revised_prompts = Vec::new();
1151
1152        for image in images.data {
1153            if let Some(prompt) = image.revised_prompt {
1154                revised_prompts.push(prompt);
1155            }
1156            if let Some(url) = image.url {
1157                assets.push(MediaOutputAsset::Url {
1158                    url,
1159                    mime_type: Some("image/jpeg".to_string()),
1160                });
1161            } else if let Some(data) = image.b64_json {
1162                assets.push(MediaOutputAsset::Base64 {
1163                    data,
1164                    mime_type: Some("image/jpeg".to_string()),
1165                });
1166            }
1167        }
1168
1169        if assets.is_empty() {
1170            anyhow::bail!("xAI image operation returned no assets");
1171        }
1172
1173        let metadata = if revised_prompts.is_empty() {
1174            None
1175        } else {
1176            Some(json!({ "revised_prompts": revised_prompts }))
1177        };
1178
1179        Ok(NativeMediaResponse::Assets { assets, metadata })
1180    }
1181
1182    async fn start_video_job<T: Serialize + ?Sized>(
1183        &self,
1184        path: &str,
1185        operation: NativeOperation,
1186        model: &str,
1187        body: &T,
1188    ) -> anyhow::Result<NativeMediaResponse> {
1189        let api_key = self.api_key()?;
1190        let response = self
1191            .client
1192            .post(self.endpoint(path))
1193            .header("Authorization", format!("Bearer {api_key}"))
1194            .json(body)
1195            .send()
1196            .await?;
1197
1198        if !response.status().is_success() {
1199            return Err(crate::api_error("xAI", response).await);
1200        }
1201
1202        let started: VideoStartResponse = response.json().await?;
1203        Ok(NativeMediaResponse::Job {
1204            job: NativeMediaJob {
1205                provider: "xai".to_string(),
1206                operation,
1207                job_id: started.request_id,
1208                status: NativeMediaJobStatus::Queued,
1209                model: Some(model.to_string()),
1210                metadata: None,
1211            },
1212        })
1213    }
1214
1215    async fn generate_video(
1216        &self,
1217        request: GenerateVideoRequest,
1218    ) -> anyhow::Result<NativeMediaResponse> {
1219        let body = VideoRequest {
1220            model: &request.model,
1221            prompt: &request.prompt,
1222            duration_seconds: request.duration_seconds,
1223            aspect_ratio: request.aspect_ratio.as_deref(),
1224            resolution: request.resolution.as_deref(),
1225            image: None,
1226            reference_images: None,
1227            video: None,
1228        };
1229        self.start_video_job(
1230            "/videos/generations",
1231            NativeOperation::GenerateVideo,
1232            &request.model,
1233            &body,
1234        )
1235        .await
1236    }
1237
1238    async fn image_to_video(
1239        &self,
1240        request: ImageToVideoRequest,
1241    ) -> anyhow::Result<NativeMediaResponse> {
1242        let body = VideoRequest {
1243            model: &request.model,
1244            prompt: &request.prompt,
1245            duration_seconds: request.duration_seconds,
1246            aspect_ratio: request.aspect_ratio.as_deref(),
1247            resolution: request.resolution.as_deref(),
1248            image: Some(xai_media_input(request.image, false)),
1249            reference_images: None,
1250            video: None,
1251        };
1252        self.start_video_job(
1253            "/videos/generations",
1254            NativeOperation::ImageToVideo,
1255            &request.model,
1256            &body,
1257        )
1258        .await
1259    }
1260
1261    async fn reference_to_video(
1262        &self,
1263        request: ReferenceToVideoRequest,
1264    ) -> anyhow::Result<NativeMediaResponse> {
1265        let body = VideoRequest {
1266            model: &request.model,
1267            prompt: &request.prompt,
1268            duration_seconds: request.duration_seconds,
1269            aspect_ratio: request.aspect_ratio.as_deref(),
1270            resolution: request.resolution.as_deref(),
1271            image: None,
1272            reference_images: Some(
1273                request
1274                    .reference_images
1275                    .into_iter()
1276                    .map(|asset| xai_media_input(asset, false))
1277                    .collect(),
1278            ),
1279            video: None,
1280        };
1281        self.start_video_job(
1282            "/videos/generations",
1283            NativeOperation::ReferenceToVideo,
1284            &request.model,
1285            &body,
1286        )
1287        .await
1288    }
1289
1290    async fn edit_video(&self, request: EditVideoRequest) -> anyhow::Result<NativeMediaResponse> {
1291        let body = VideoRequest {
1292            model: &request.model,
1293            prompt: &request.prompt,
1294            duration_seconds: None,
1295            aspect_ratio: None,
1296            resolution: None,
1297            image: None,
1298            reference_images: None,
1299            video: Some(xai_media_input(request.video, false)),
1300        };
1301        self.start_video_job(
1302            "/videos/edits",
1303            NativeOperation::EditVideo,
1304            &request.model,
1305            &body,
1306        )
1307        .await
1308    }
1309
1310    async fn extend_video(
1311        &self,
1312        request: ExtendVideoRequest,
1313    ) -> anyhow::Result<NativeMediaResponse> {
1314        let body = VideoRequest {
1315            model: &request.model,
1316            prompt: &request.prompt,
1317            duration_seconds: request.duration_seconds,
1318            aspect_ratio: None,
1319            resolution: None,
1320            image: None,
1321            reference_images: None,
1322            video: Some(xai_media_input(request.video, false)),
1323        };
1324        self.start_video_job(
1325            "/videos/extensions",
1326            NativeOperation::ExtendVideo,
1327            &request.model,
1328            &body,
1329        )
1330        .await
1331    }
1332}
1333
1334#[async_trait]
1335impl ModelProvider for XAiProvider {
1336    async fn chat(
1337        &self,
1338        request: ChatRequest<'_>,
1339        model: &str,
1340        temperature: f64,
1341    ) -> anyhow::Result<ChatResponse> {
1342        if let Some(native_tools) = request.native_tools
1343            && !native_tools.is_empty()
1344        {
1345            return self
1346                .chat_with_native_model_tools(request, model, temperature, native_tools)
1347                .await;
1348        }
1349        self.chat.chat(request, model, temperature).await
1350    }
1351
1352    async fn chat_stream(
1353        &self,
1354        request: ChatRequest<'_>,
1355        model: &str,
1356        temperature: f64,
1357        events: tokio::sync::mpsc::UnboundedSender<ProviderStreamEvent>,
1358    ) -> anyhow::Result<ChatResponse> {
1359        if let Some(native_tools) = request.native_tools
1360            && !native_tools.is_empty()
1361        {
1362            return self
1363                .chat_with_native_model_tools_streaming(
1364                    request,
1365                    model,
1366                    temperature,
1367                    native_tools,
1368                    events,
1369                )
1370                .await;
1371        }
1372        self.chat
1373            .chat_stream(request, model, temperature, events)
1374            .await
1375    }
1376
1377    fn context_window(&self, model: &str) -> Option<usize> {
1378        self.chat.context_window(model)
1379    }
1380
1381    fn supports_native_tools(&self) -> bool {
1382        true
1383    }
1384
1385    fn supports_developer_role(&self, model: &str) -> bool {
1386        self.chat.supports_developer_role(model)
1387    }
1388
1389    fn native_capabilities(&self) -> Option<ProviderNativeCapabilities> {
1390        Some(NativeCapabilitiesProvider::native_capabilities(self))
1391    }
1392
1393    async fn submit_media(
1394        &self,
1395        request: NativeMediaRequest,
1396    ) -> anyhow::Result<NativeMediaResponse> {
1397        NativeCapabilitiesProvider::submit_media(self, request).await
1398    }
1399
1400    async fn poll_media_job(&self, job: &NativeMediaJob) -> anyhow::Result<NativeMediaResponse> {
1401        NativeCapabilitiesProvider::poll_media_job(self, job).await
1402    }
1403
1404    async fn warmup(&self) -> anyhow::Result<()> {
1405        self.chat.warmup().await
1406    }
1407}
1408
1409#[async_trait]
1410impl NativeCapabilitiesProvider for XAiProvider {
1411    fn native_capabilities(&self) -> ProviderNativeCapabilities {
1412        ProviderNativeCapabilities {
1413            provider: "xai".to_string(),
1414            model_tools: xai_native_model_tool_specs(),
1415            models: vec![
1416                ModelNativeCapabilities {
1417                    model_pattern: "grok-imagine-image-quality".to_string(),
1418                    tools: vec![
1419                        xai_image_tool_spec(NativeOperation::GenerateImage),
1420                        xai_image_tool_spec(NativeOperation::EditImage),
1421                    ],
1422                },
1423                ModelNativeCapabilities {
1424                    model_pattern: "grok-imagine-video*".to_string(),
1425                    tools: vec![
1426                        xai_video_tool_spec(NativeOperation::GenerateVideo),
1427                        xai_video_tool_spec(NativeOperation::EditVideo),
1428                        xai_video_tool_spec(NativeOperation::ImageToVideo),
1429                        xai_video_tool_spec(NativeOperation::ReferenceToVideo),
1430                        xai_video_tool_spec(NativeOperation::ExtendVideo),
1431                    ],
1432                },
1433            ],
1434        }
1435    }
1436
1437    async fn submit_media(
1438        &self,
1439        request: NativeMediaRequest,
1440    ) -> anyhow::Result<NativeMediaResponse> {
1441        let operation = request.operation();
1442        match request {
1443            NativeMediaRequest::GenerateImage(request) => self.generate_image(request).await,
1444            NativeMediaRequest::EditImage(request) => self.edit_image(request).await,
1445            NativeMediaRequest::GenerateVideo(request) => self.generate_video(request).await,
1446            NativeMediaRequest::EditVideo(request) => self.edit_video(request).await,
1447            NativeMediaRequest::ImageToVideo(request) => self.image_to_video(request).await,
1448            NativeMediaRequest::ReferenceToVideo(request) => self.reference_to_video(request).await,
1449            NativeMediaRequest::ExtendVideo(request) => self.extend_video(request).await,
1450            NativeMediaRequest::GenerateSpeech(_) | NativeMediaRequest::TranscribeAudio(_) => {
1451                anyhow::bail!(
1452                    "xAI native operation {operation:?} is declared but not implemented in this pass"
1453                )
1454            }
1455        }
1456    }
1457
1458    async fn poll_media_job(&self, job: &NativeMediaJob) -> anyhow::Result<NativeMediaResponse> {
1459        let api_key = self.api_key()?;
1460        let response = self
1461            .client
1462            .get(self.endpoint(format!("/videos/{}", job.job_id).as_str()))
1463            .header("Authorization", format!("Bearer {api_key}"))
1464            .send()
1465            .await?;
1466
1467        if !response.status().is_success() {
1468            return Err(crate::api_error("xAI", response).await);
1469        }
1470
1471        let polled: VideoPollResponse = response.json().await?;
1472        let status = xai_video_status(&polled.status)?;
1473        if status == NativeMediaJobStatus::Completed {
1474            let video = polled.video.ok_or_else(|| {
1475                anyhow::anyhow!("xAI video job {} completed without a video", job.job_id)
1476            })?;
1477            let metadata = video
1478                .duration
1479                .map(|duration| json!({ "duration_seconds": duration }));
1480            return Ok(NativeMediaResponse::Assets {
1481                assets: vec![MediaOutputAsset::Url {
1482                    url: video.url,
1483                    mime_type: Some("video/mp4".to_string()),
1484                }],
1485                metadata,
1486            });
1487        }
1488
1489        let metadata = polled
1490            .error
1491            .and_then(|error| serde_json::to_value(error).ok());
1492        Ok(NativeMediaResponse::Job {
1493            job: NativeMediaJob {
1494                provider: job.provider.clone(),
1495                operation: job.operation,
1496                job_id: job.job_id.clone(),
1497                status,
1498                model: job.model.clone(),
1499                metadata,
1500            },
1501        })
1502    }
1503}
1504
#[cfg(test)]
mod tests {
    use super::*;

    /// `new` stores the default base URL unchanged.
    #[test]
    fn creates_with_default_base_url() {
        assert_eq!(
            XAiProvider::new(Some("xai-key")).base_url,
            XAI_DEFAULT_BASE_URL
        );
    }

    /// The video model entry advertises the image-, reference- and
    /// extension-based operations.
    #[test]
    fn capabilities_include_xai_video_modes() {
        let provider = XAiProvider::new(None);
        let capabilities = NativeCapabilitiesProvider::native_capabilities(&provider);
        let video = capabilities
            .models
            .iter()
            .find(|model| model.model_pattern == "grok-imagine-video*")
            .expect("video capability");

        for expected in [
            NativeOperation::ImageToVideo,
            NativeOperation::ReferenceToVideo,
            NativeOperation::ExtendVideo,
        ] {
            assert!(video.operations().any(|op| op == expected));
        }
    }

    /// Each raw xAI status string maps to the expected native status.
    #[test]
    fn xai_video_status_maps_to_native_status() {
        let cases = [
            ("pending", NativeMediaJobStatus::Running),
            ("done", NativeMediaJobStatus::Completed),
            ("expired", NativeMediaJobStatus::Expired),
            ("failed", NativeMediaJobStatus::Failed),
        ];

        for (raw, expected) in cases {
            assert_eq!(xai_video_status(raw).expect(raw), expected);
        }
    }

    /// A finished poll payload deserializes, including an integer duration.
    #[test]
    fn xai_video_poll_response_matches_rest_done_shape() {
        let payload = json!({
            "status": "done",
            "video": {
                "url": "https://vidgen.x.ai/example/video.mp4",
                "duration": 8,
                "respect_moderation": true
            },
            "model": "grok-imagine-video"
        });
        let parsed: VideoPollResponse =
            serde_json::from_value(payload).expect("poll response should parse");

        assert_eq!(parsed.status, "done");
        let video = parsed.video.expect("video asset");
        assert_eq!(video.url, "https://vidgen.x.ai/example/video.mp4");
        assert_eq!(video.duration, Some(8.0));
    }

    /// A URL image input for edits serializes with the `image_url` type tag.
    #[test]
    fn xai_image_edit_input_uses_image_url_shape() {
        let asset = MediaInputAsset::Url {
            url: "https://example.com/image.png".to_string(),
        };
        let serialized = serde_json::to_value(xai_media_input(asset, true)).expect("serialize");

        assert_eq!(serialized["type"], "image_url");
        assert_eq!(serialized["url"], "https://example.com/image.png");
    }

    /// Native tools come first, followed by the caller's function tools.
    #[test]
    fn xai_responses_tools_include_native_and_local_tools() {
        let native = [
            NativeModelToolId::from("web_search"),
            NativeModelToolId::from("x_search"),
        ];
        let tools = native_responses_tools(
            &native,
            Some(&[crate::ToolSpec {
                name: "shell".to_string(),
                description: "Run a shell command.".to_string(),
                parameters: json!({
                    "type": "object",
                    "properties": {
                        "cmd": { "type": "string" }
                    },
                    "required": ["cmd"]
                }),
                category: crate::ToolCategory::Write,
            }]),
        )
        .expect("supported tools");

        assert_eq!(tools[0].kind, "web_search");
        assert_eq!(tools[1].kind, "x_search");
        assert_eq!(tools[2].kind, "function");
        assert_eq!(tools[2].name.as_deref(), Some("shell"));
    }

    /// An unknown native tool id is rejected and named in the error.
    #[test]
    fn xai_responses_tools_reject_unknown_native_tool_ids() {
        let unknown = [NativeModelToolId::from("unknown_tool")];
        let error =
            native_responses_tools(&unknown, None).expect_err("unsupported tool should fail");

        assert!(error.to_string().contains("unknown_tool"));
    }

    /// `function_call` output items become tool calls.
    #[test]
    fn xai_responses_extracts_function_calls() {
        let payload = json!({
            "output": [
                {
                    "type": "function_call",
                    "call_id": "call_123",
                    "name": "shell",
                    "arguments": "{\"cmd\":\"date\"}"
                }
            ],
            "usage": {
                "input_tokens": 5,
                "output_tokens": 3
            }
        });
        let response: ResponsesResponse =
            serde_json::from_value(payload).expect("responses payload should parse");

        let calls = responses_tool_calls(&response);
        assert_eq!(calls.len(), 1);
        let call = &calls[0];
        assert_eq!(call.id, "call_123");
        assert_eq!(call.name, "shell");
        assert_eq!(call.arguments, "{\"cmd\":\"date\"}");
    }

    /// A `web_search_call` item yields one provider trace with output and
    /// citations.
    #[test]
    fn xai_responses_extracts_provider_native_tool_traces() {
        let payload = json!({
            "output": [
                {
                    "id": "ws_123",
                    "type": "web_search_call",
                    "status": "completed",
                    "action": {
                        "type": "search",
                        "query": "latest xAI models"
                    },
                    "results": [
                        { "title": "xAI Docs", "url": "https://docs.x.ai/developers/models" }
                    ]
                },
                {
                    "type": "message",
                    "content": [
                        {
                            "type": "output_text",
                            "text": "xAI has new models.",
                            "annotations": [
                                { "type": "url_citation", "url": "https://docs.x.ai/developers/models" }
                            ]
                        }
                    ]
                }
            ]
        });
        let response: ResponsesResponse =
            serde_json::from_value(payload).expect("responses payload should parse");

        let traces = responses_provider_tool_traces(&response);
        assert_eq!(traces.len(), 1);
        let trace = &traces[0];
        assert_eq!(trace.id, "ws_123");
        assert_eq!(trace.name, "web_search");
        assert_eq!(trace.provider, "xai");
        assert_eq!(trace.input["status"], "completed");
        assert!(trace.output.is_some());
        assert_eq!(trace.citations.len(), 1);
    }

    /// An added item followed by a done item emits a start event and then a
    /// completion event.
    #[test]
    fn xai_stream_parser_emits_provider_tool_start_and_completion() {
        let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel();
        let mut state = ResponsesStreamState::default();

        let added = json!({
            "type": "response.output_item.added",
            "item": {
                "id": "ws_123",
                "type": "web_search_call",
                "status": "in_progress",
                "action": {
                    "type": "search",
                    "query": "latest xAI models"
                }
            }
        });
        let done = json!({
            "type": "response.output_item.done",
            "item": {
                "id": "ws_123",
                "type": "web_search_call",
                "status": "completed",
                "results": [
                    { "title": "xAI Docs", "url": "https://docs.x.ai/developers/models" }
                ]
            }
        });

        handle_responses_stream_value(added, &mut state, &tx);
        handle_responses_stream_value(done, &mut state, &tx);

        match rx.try_recv().expect("start event") {
            ProviderStreamEvent::ProviderToolStarted(trace) => {
                assert_eq!(trace.id, "ws_123");
                assert_eq!(trace.name, "web_search");
            }
            other => panic!("unexpected event: {other:?}"),
        }
        match rx.try_recv().expect("completion event") {
            ProviderStreamEvent::ProviderToolCompleted(trace) => {
                assert_eq!(trace.id, "ws_123");
                assert_eq!(trace.name, "web_search");
                assert!(!trace.citations.is_empty());
            }
            other => panic!("unexpected event: {other:?}"),
        }
    }

    /// A raw `response.web_search_call.in_progress` event emits a start event.
    #[test]
    fn xai_stream_parser_tolerates_raw_provider_tool_events() {
        let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel();
        let mut state = ResponsesStreamState::default();

        let raw_event = json!({
            "type": "response.web_search_call.in_progress",
            "item_id": "ws_raw_123",
            "query": "current events"
        });
        handle_responses_stream_value(raw_event, &mut state, &tx);

        match rx.try_recv().expect("start event") {
            ProviderStreamEvent::ProviderToolStarted(trace) => {
                assert_eq!(trace.id, "ws_raw_123");
                assert_eq!(trace.name, "web_search");
                assert_eq!(trace.input["query"], "current events");
            }
            other => panic!("unexpected event: {other:?}"),
        }
    }
}