1use async_trait::async_trait;
7use futures_util::StreamExt;
8use reqwest::Client;
9use serde::{Deserialize, Serialize};
10use serde_json::{Value, json};
11use std::collections::{HashMap, HashSet};
12
13use crate::compatible::{AuthStyle, OpenAiCompatibleProvider};
14use crate::native::{
15 EditImageRequest, EditVideoRequest, ExtendVideoRequest, GenerateVideoRequest,
16 ImageToVideoRequest, MediaInputAsset, MediaOutputAsset, MediaOutputFormat,
17 ModelNativeCapabilities, NativeCapabilitiesProvider, NativeExecutionMode, NativeMediaJob,
18 NativeMediaJobStatus, NativeMediaRequest, NativeMediaResponse, NativeModelToolId,
19 NativeOperation, NativeToolSpec, ProviderNativeCapabilities, ProviderNativeModelToolSpec,
20 ReferenceToVideoRequest, media_input_schema,
21};
22use crate::traits::{
23 ChatMessage, ChatRequest, ChatResponse, ModelProvider, ProviderStreamEvent, ProviderToolTrace,
24 TokenUsage, ToolCall,
25};
26
/// Default base URL of the xAI REST API; [`XAiProvider::new`] uses it.
pub const XAI_DEFAULT_BASE_URL: &str = "https://api.x.ai/v1";
28
/// Model provider for xAI.
///
/// Native Responses API calls (`/responses`) and the image endpoints are issued
/// directly through `client`. An OpenAI-compatible client (`chat`) is kept
/// alongside it; its use is outside this part of the file.
pub struct XAiProvider {
    /// Bearer token. `None` makes native calls fail with a "key not set" error.
    api_key: Option<String>,
    /// API base URL with trailing `/` characters removed.
    base_url: String,
    /// OpenAI-compatible client configured with the same base URL and key.
    chat: OpenAiCompatibleProvider,
    /// HTTP client for native endpoints (120 s request / 10 s connect timeout
    /// when the builder succeeds; a default `Client` otherwise).
    client: Client,
}
35
/// JSON body for `POST /images/generations`.
///
/// Field order is the key order on the wire; optional fields are omitted from
/// the JSON entirely when `None`.
#[derive(Debug, Serialize)]
struct ImageGenerationRequest<'a> {
    model: &'a str,
    prompt: &'a str,
    /// Number of images to generate.
    #[serde(skip_serializing_if = "Option::is_none")]
    n: Option<u32>,
    /// `Some("b64_json")` requests inline base64 data; `None` omits the field.
    #[serde(skip_serializing_if = "Option::is_none")]
    response_format: Option<&'static str>,
    #[serde(skip_serializing_if = "Option::is_none")]
    aspect_ratio: Option<&'a str>,
    #[serde(skip_serializing_if = "Option::is_none")]
    resolution: Option<&'a str>,
}

/// JSON body for `POST /images/edits`: an image request plus the source image.
#[derive(Debug, Serialize)]
struct ImageEditRequest<'a> {
    model: &'a str,
    prompt: &'a str,
    /// Image to edit, in xAI's media-input shape.
    image: XaiMediaInput,
    /// `Some("b64_json")` requests inline base64 data; `None` omits the field.
    #[serde(skip_serializing_if = "Option::is_none")]
    response_format: Option<&'static str>,
    #[serde(skip_serializing_if = "Option::is_none")]
    aspect_ratio: Option<&'a str>,
    #[serde(skip_serializing_if = "Option::is_none")]
    resolution: Option<&'a str>,
}

/// JSON body for an xAI video job. The request builders live outside this part
/// of the file; presumably each video operation fills only its relevant input
/// (`image`, `reference_images` or `video`) — confirm at the call sites.
#[derive(Debug, Serialize)]
struct VideoRequest<'a> {
    model: &'a str,
    prompt: &'a str,
    /// Serialized under the wire name `duration`.
    #[serde(rename = "duration", skip_serializing_if = "Option::is_none")]
    duration_seconds: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    aspect_ratio: Option<&'a str>,
    #[serde(skip_serializing_if = "Option::is_none")]
    resolution: Option<&'a str>,
    #[serde(skip_serializing_if = "Option::is_none")]
    image: Option<XaiMediaInput>,
    #[serde(skip_serializing_if = "Option::is_none")]
    reference_images: Option<Vec<XaiMediaInput>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    video: Option<XaiMediaInput>,
}

/// A media input in xAI's wire shape: `{"type"?, "url"?, "file_id"?}`.
/// `xai_media_input` fills exactly one of `url` or `file_id`.
#[derive(Debug, Serialize)]
struct XaiMediaInput {
    /// Serialized as `type`; `xai_media_input` only sets it to `"image_url"`.
    #[serde(rename = "type", skip_serializing_if = "Option::is_none")]
    kind: Option<&'static str>,
    /// Plain URL or data URI.
    #[serde(skip_serializing_if = "Option::is_none")]
    url: Option<String>,
    /// Identifier of a file already uploaded to the provider.
    #[serde(skip_serializing_if = "Option::is_none")]
    file_id: Option<String>,
}
90
/// Response body of `/images/generations` (presumably `/images/edits` as well;
/// that parser is outside this part of the file).
#[derive(Debug, Deserialize)]
struct ImageGenerationResponse {
    data: Vec<ImageGenerationData>,
}

/// One generated image. `generate_image` uses `url` when present, falls back to
/// `b64_json`, and skips entries carrying neither.
#[derive(Debug, Deserialize)]
struct ImageGenerationData {
    #[serde(default)]
    url: Option<String>,
    #[serde(default)]
    b64_json: Option<String>,
    /// Prompt as rewritten by the provider, if it reported one.
    #[serde(default)]
    revised_prompt: Option<String>,
}

/// Reply to submitting a video job; `request_id` presumably identifies the job
/// for later polling (polling code is outside this part of the file).
#[derive(Debug, Deserialize)]
struct VideoStartResponse {
    request_id: String,
}

/// Poll result for a video job. `status` is presumably mapped through
/// `xai_video_status`; `video` and `error` are optional.
#[derive(Debug, Deserialize)]
struct VideoPollResponse {
    status: String,
    #[serde(default)]
    video: Option<VideoAsset>,
    #[serde(default)]
    error: Option<XaiError>,
}

/// Finished video. NOTE(review): `duration` units are not shown here —
/// presumably seconds; confirm against the xAI API.
#[derive(Debug, Deserialize)]
struct VideoAsset {
    url: String,
    #[serde(default)]
    duration: Option<f64>,
}

/// Error object returned by xAI; both fields are optional.
#[derive(Debug, Deserialize, Serialize)]
struct XaiError {
    #[serde(default)]
    code: Option<String>,
    #[serde(default)]
    message: Option<String>,
}
134
/// JSON body for `POST /responses` (the xAI Responses API). Every field is
/// always serialized.
#[derive(Debug, Serialize)]
struct ResponsesRequest {
    model: String,
    input: Vec<ResponsesInput>,
    tools: Vec<ResponsesTool>,
    temperature: f64,
    stream: bool,
}

/// One Responses API input item.
///
/// `#[serde(untagged)]` means the variant name is never written: each variant
/// serializes as its bare fields, with `kind` emitted under the key `type`.
#[derive(Debug, Serialize)]
#[serde(untagged)]
enum ResponsesInput {
    /// Plain `{role, content}` message.
    Message {
        role: String,
        content: String,
    },
    /// A tool call replayed from history (`responses_input` sets `kind` to
    /// `"function_call"`).
    FunctionCall {
        #[serde(rename = "type")]
        kind: &'static str,
        call_id: String,
        name: String,
        arguments: String,
    },
    /// A tool result (`responses_input` sets `kind` to `"function_call_output"`).
    FunctionCallOutput {
        #[serde(rename = "type")]
        kind: &'static str,
        call_id: String,
        output: String,
    },
}

/// Entry in the request's `tools` array. Native xAI tools set only `kind`
/// (e.g. `web_search`); local function tools set `kind = "function"` plus a
/// name, description and JSON-schema parameters.
#[derive(Debug, Serialize, PartialEq)]
struct ResponsesTool {
    #[serde(rename = "type")]
    kind: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    name: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    description: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    parameters: Option<Value>,
}

/// The subset of a Responses API reply this provider reads. Every field
/// defaults, so sparse payloads still deserialize.
#[derive(Debug, Clone, Deserialize)]
struct ResponsesResponse {
    #[serde(default)]
    output: Vec<ResponsesOutput>,
    #[serde(default)]
    output_text: Option<String>,
    #[serde(default)]
    usage: Option<ResponsesUsage>,
}

/// One item of the reply's `output` array (message, function call, or
/// provider-executed tool call).
#[derive(Debug, Clone, Deserialize)]
struct ResponsesOutput {
    #[serde(default)]
    id: Option<String>,
    #[serde(default)]
    call_id: Option<String>,
    /// Item `type`, e.g. `function_call` or `web_search_call`.
    #[serde(rename = "type", default)]
    kind: Option<String>,
    #[serde(default)]
    name: Option<String>,
    /// Either a JSON-encoded string or structured JSON; callers handle both.
    #[serde(default)]
    arguments: Option<Value>,
    #[serde(default)]
    content: Vec<ResponsesContent>,
    #[serde(default)]
    status: Option<String>,
    /// Every field not named above (via `flatten`); this is how tool-specific
    /// payloads such as `action`, `query` or `citations` are reached.
    #[serde(flatten)]
    extra: serde_json::Map<String, Value>,
}

/// A content part of an output item. Also `Serialize` so it can be copied into
/// provider tool trace output.
#[derive(Debug, Clone, Deserialize, Serialize)]
struct ResponsesContent {
    #[serde(rename = "type", default)]
    kind: Option<String>,
    #[serde(default)]
    text: Option<String>,
    #[serde(default)]
    annotations: Vec<Value>,
}

/// Token counts; the aliases also accept Chat Completions field names.
#[derive(Debug, Clone, Deserialize)]
struct ResponsesUsage {
    #[serde(default, alias = "prompt_tokens")]
    input_tokens: u64,
    #[serde(default, alias = "completion_tokens")]
    output_tokens: u64,
}
225
226fn xai_media_input(asset: MediaInputAsset, image_edit_input: bool) -> XaiMediaInput {
227 let kind = match &asset {
228 MediaInputAsset::ProviderFileId { .. } => None,
229 MediaInputAsset::Url { .. } | MediaInputAsset::DataUri { .. } => {
230 image_edit_input.then_some("image_url")
231 }
232 };
233 match asset {
234 MediaInputAsset::Url { url } => XaiMediaInput {
235 kind,
236 url: Some(url),
237 file_id: None,
238 },
239 MediaInputAsset::DataUri { data_uri } => XaiMediaInput {
240 kind,
241 url: Some(data_uri),
242 file_id: None,
243 },
244 MediaInputAsset::ProviderFileId { file_id } => XaiMediaInput {
245 kind,
246 url: None,
247 file_id: Some(file_id),
248 },
249 }
250}
251
252fn xai_image_tool_spec(operation: NativeOperation) -> NativeToolSpec {
253 let mut properties = json!({
254 "prompt": {"type": "string"},
255 "n": {"type": "integer", "minimum": 1},
256 "aspect_ratio": {
257 "type": "string",
258 "enum": [
259 "1:1", "16:9", "9:16", "4:3", "3:4", "3:2", "2:3",
260 "2:1", "1:2", "19.5:9", "9:19.5", "20:9", "9:20", "auto"
261 ]
262 },
263 "resolution": {"type": "string", "enum": ["1k", "2k"]},
264 "output_format": {"type": "string", "enum": ["url", "base64"]},
265 "provider_options": {
266 "type": "object",
267 "properties": {},
268 "additionalProperties": false
269 }
270 });
271 let required = match operation {
272 NativeOperation::GenerateImage => vec!["prompt"],
273 NativeOperation::EditImage => {
274 properties["image"] = media_input_schema();
275 vec!["prompt", "image"]
276 }
277 other => panic!("unsupported xAI image operation {other:?}"),
278 };
279
280 NativeToolSpec {
281 capability: operation,
282 tool_name: operation.tool_name().unwrap().to_string(),
283 description: match operation {
284 NativeOperation::GenerateImage => {
285 "Generate an image with the configured xAI image model."
286 }
287 NativeOperation::EditImage => "Edit an image with the configured xAI image model.",
288 _ => unreachable!(),
289 }
290 .to_string(),
291 parameters_schema: json!({
292 "type": "object",
293 "properties": properties,
294 "required": required
295 }),
296 execution: NativeExecutionMode::Immediate,
297 }
298}
299
300fn xai_video_provider_options() -> Value {
301 json!({
302 "type": "object",
303 "properties": {
304 "poll_timeout_ms": {
305 "type": "integer",
306 "minimum": 1
307 }
308 },
309 "additionalProperties": false
310 })
311}
312
313fn xai_video_base_properties() -> Value {
314 json!({
315 "prompt": {"type": "string"},
316 "duration_seconds": {"type": "integer", "minimum": 1},
317 "aspect_ratio": {"type": "string", "enum": ["16:9", "9:16", "1:1"]},
318 "resolution": {"type": "string", "enum": ["480p", "720p"]},
319 "provider_options": xai_video_provider_options()
320 })
321}
322
323fn xai_video_tool_spec(operation: NativeOperation) -> NativeToolSpec {
324 let mut properties = xai_video_base_properties();
325 let required = match operation {
326 NativeOperation::GenerateVideo => vec!["prompt"],
327 NativeOperation::ImageToVideo => {
328 properties["image"] = media_input_schema();
329 vec!["prompt", "image"]
330 }
331 NativeOperation::ReferenceToVideo => {
332 properties["reference_images"] = json!({
333 "type": "array",
334 "items": media_input_schema(),
335 "minItems": 1,
336 "maxItems": 7
337 });
338 properties["duration_seconds"]["maximum"] = json!(10);
339 vec!["prompt", "reference_images"]
340 }
341 NativeOperation::EditVideo => {
342 properties = json!({
343 "prompt": {"type": "string"},
344 "video": media_input_schema(),
345 "provider_options": xai_video_provider_options()
346 });
347 vec!["prompt", "video"]
348 }
349 NativeOperation::ExtendVideo => {
350 properties = json!({
351 "prompt": {"type": "string"},
352 "video": media_input_schema(),
353 "duration_seconds": {
354 "type": "integer",
355 "minimum": 2,
356 "maximum": 10
357 },
358 "provider_options": xai_video_provider_options()
359 });
360 vec!["prompt", "video"]
361 }
362 other => panic!("unsupported xAI video operation {other:?}"),
363 };
364
365 NativeToolSpec {
366 capability: operation,
367 tool_name: operation.tool_name().unwrap().to_string(),
368 description: match operation {
369 NativeOperation::GenerateVideo => "Start an asynchronous xAI video generation job. A successful call means the render was queued, not finished; use wait_operations with kind=media for the returned operation_id until it completes. Do not call generate_video again for the same prompt unless the user explicitly asks for another independent video.",
370 NativeOperation::EditVideo => "Start an asynchronous xAI video editing job. A successful call means the render was queued, not finished; use wait_operations with kind=media for the returned operation_id until it completes. Do not call edit_video again for the same request unless the user explicitly asks for another independent edit.",
371 NativeOperation::ImageToVideo => "Start an asynchronous xAI image-to-video job. A successful call means the render was queued, not finished; use wait_operations with kind=media for the returned operation_id until it completes. Do not call image_to_video again for the same request unless the user explicitly asks for another independent video.",
372 NativeOperation::ReferenceToVideo => "Start an asynchronous xAI reference-to-video job. A successful call means the render was queued, not finished; use wait_operations with kind=media for the returned operation_id until it completes. Do not call reference_to_video again for the same request unless the user explicitly asks for another independent video.",
373 NativeOperation::ExtendVideo => "Start an asynchronous xAI video extension job. A successful call means the render was queued, not finished; use wait_operations with kind=media for the returned operation_id until it completes. Do not call extend_video again for the same request unless the user explicitly asks for another independent extension.",
374 _ => unreachable!(),
375 }
376 .to_string(),
377 parameters_schema: json!({
378 "type": "object",
379 "properties": properties,
380 "required": required
381 }),
382 execution: NativeExecutionMode::AsyncJob {
383 poll_supported: true,
384 },
385 }
386}
387
388fn xai_video_status(status: &str) -> anyhow::Result<NativeMediaJobStatus> {
389 match status {
390 "pending" => Ok(NativeMediaJobStatus::Running),
391 "done" => Ok(NativeMediaJobStatus::Completed),
392 "expired" => Ok(NativeMediaJobStatus::Expired),
393 "failed" => Ok(NativeMediaJobStatus::Failed),
394 other => anyhow::bail!("unknown xAI video job status '{other}'"),
395 }
396}
397
/// Returns `text` with surrounding whitespace trimmed, or `None` when it is
/// absent or blank.
fn first_nonempty(text: Option<&str>) -> Option<String> {
    text.map(str::trim)
        .filter(|trimmed| !trimmed.is_empty())
        .map(str::to_owned)
}
408
409fn xai_native_model_tool_specs() -> Vec<ProviderNativeModelToolSpec> {
410 vec![
411 ProviderNativeModelToolSpec {
412 id: NativeModelToolId::from("web_search"),
413 provider_type: "web_search".to_string(),
414 name: "web_search".to_string(),
415 description: "Provider-native xAI web search for current web results and citations."
416 .to_string(),
417 parameters_schema: Some(json!({
418 "type": "object",
419 "properties": {},
420 "additionalProperties": false
421 })),
422 config_schema: None,
423 },
424 ProviderNativeModelToolSpec {
425 id: NativeModelToolId::from("x_search"),
426 provider_type: "x_search".to_string(),
427 name: "x_search".to_string(),
428 description:
429 "Provider-native xAI X search for posts, discussions, and current activity on X."
430 .to_string(),
431 parameters_schema: Some(json!({
432 "type": "object",
433 "properties": {},
434 "additionalProperties": false
435 })),
436 config_schema: None,
437 },
438 ]
439}
440
441fn xai_native_model_tool_spec(tool_id: &NativeModelToolId) -> Option<ProviderNativeModelToolSpec> {
442 xai_native_model_tool_specs()
443 .into_iter()
444 .find(|spec| spec.id == *tool_id)
445}
446
447fn native_responses_tools(
448 native_tools: &[NativeModelToolId],
449 local_tools: Option<&[crate::ToolSpec]>,
450) -> anyhow::Result<Vec<ResponsesTool>> {
451 let mut tools = Vec::with_capacity(native_tools.len() + local_tools.map_or(0, <[_]>::len));
452 for tool_id in native_tools {
453 let tool = xai_native_model_tool_spec(tool_id)
454 .ok_or_else(|| anyhow::anyhow!("xAI does not support native model tool '{tool_id}'"))?;
455 tools.push(ResponsesTool {
456 kind: tool.provider_type,
457 name: None,
458 description: None,
459 parameters: None,
460 });
461 }
462
463 if let Some(local_tools) = local_tools {
464 tools.extend(local_tools.iter().map(|tool| ResponsesTool {
465 kind: "function".to_string(),
466 name: Some(crate::sanitize_tool_name(&tool.name)),
467 description: Some(tool.description.clone()),
468 parameters: Some(tool.parameters.clone()),
469 }));
470 }
471
472 Ok(tools)
473}
474
475fn responses_input(messages: &[ChatMessage]) -> Vec<ResponsesInput> {
476 let mut input = Vec::with_capacity(messages.len());
477
478 for message in messages {
479 if message.role == "assistant"
480 && let Ok(value) = serde_json::from_str::<Value>(&message.content)
481 && let Some(tool_calls_value) = value.get("tool_calls")
482 && let Ok(tool_calls) =
483 serde_json::from_value::<Vec<ToolCall>>(tool_calls_value.clone())
484 {
485 if let Some(content) = value
486 .get("content")
487 .and_then(Value::as_str)
488 .and_then(|text| first_nonempty(Some(text)))
489 {
490 input.push(ResponsesInput::Message {
491 role: "assistant".to_string(),
492 content,
493 });
494 }
495
496 input.extend(
497 tool_calls
498 .into_iter()
499 .map(|call| ResponsesInput::FunctionCall {
500 kind: "function_call",
501 call_id: call.id,
502 name: call.name,
503 arguments: call.arguments,
504 }),
505 );
506 continue;
507 }
508
509 if message.role == "tool"
510 && let Ok(value) = serde_json::from_str::<Value>(&message.content)
511 && let Some(call_id) = value.get("tool_call_id").and_then(Value::as_str)
512 {
513 let output = value
514 .get("content")
515 .and_then(Value::as_str)
516 .unwrap_or_default()
517 .to_string();
518 input.push(ResponsesInput::FunctionCallOutput {
519 kind: "function_call_output",
520 call_id: call_id.to_string(),
521 output,
522 });
523 continue;
524 }
525
526 input.push(ResponsesInput::Message {
527 role: message.role.clone(),
528 content: message.content.clone(),
529 });
530 }
531
532 input
533}
534
535fn responses_text(response: &ResponsesResponse) -> Option<String> {
536 if let Some(text) = first_nonempty(response.output_text.as_deref()) {
537 return Some(text);
538 }
539
540 for item in &response.output {
541 for content in &item.content {
542 if content.kind.as_deref() == Some("output_text")
543 && let Some(text) = first_nonempty(content.text.as_deref())
544 {
545 return Some(text);
546 }
547 }
548 }
549
550 for item in &response.output {
551 for content in &item.content {
552 if let Some(text) = first_nonempty(content.text.as_deref()) {
553 return Some(text);
554 }
555 }
556 }
557
558 None
559}
560
561fn responses_tool_calls(response: &ResponsesResponse) -> Vec<ToolCall> {
562 response
563 .output
564 .iter()
565 .filter(|item| item.kind.as_deref() == Some("function_call"))
566 .filter_map(|item| {
567 let name = item.name.clone()?;
568 let arguments = match item.arguments.as_ref() {
569 Some(Value::String(value)) => value.clone(),
570 Some(value) => value.to_string(),
571 None => "{}".to_string(),
572 };
573 Some(ToolCall {
574 id: item
575 .call_id
576 .clone()
577 .or_else(|| item.id.clone())
578 .unwrap_or_else(|| uuid::Uuid::new_v4().to_string()),
579 name,
580 arguments,
581 })
582 })
583 .collect()
584}
585
/// Maps a Responses API output item type (e.g. `web_search_call`) to the name
/// of the provider-native tool that produced it; `None` for other item types.
fn xai_native_tool_name(output_kind: &str) -> Option<&'static str> {
    const NATIVE_TOOLS: [(&str, &str); 5] = [
        ("web_search_call", "web_search"),
        ("x_search_call", "x_search"),
        ("code_interpreter_call", "code_interpreter"),
        ("file_search_call", "file_search"),
        ("mcp_call", "mcp"),
    ];

    NATIVE_TOOLS
        .iter()
        .find(|(item_type, _)| *item_type == output_kind)
        .map(|&(_, tool)| tool)
}
596
597fn provider_tool_trace_from_responses_output(item: &ResponsesOutput) -> Option<ProviderToolTrace> {
598 let kind = item.kind.as_deref()?;
599 let name = xai_native_tool_name(kind)?;
600 let id = item
601 .call_id
602 .clone()
603 .or_else(|| item.id.clone())
604 .unwrap_or_else(|| uuid::Uuid::new_v4().to_string());
605
606 let mut input = serde_json::Map::new();
607 input.insert(
608 "response_item_type".to_string(),
609 Value::String(kind.to_string()),
610 );
611 if let Some(status) = &item.status {
612 input.insert("status".to_string(), Value::String(status.clone()));
613 }
614 if let Some(arguments) = &item.arguments {
615 input.insert("arguments".to_string(), arguments.clone());
616 }
617 if let Some(name) = &item.name {
618 input.insert("name".to_string(), Value::String(name.clone()));
619 }
620 for key in [
621 "action",
622 "query",
623 "queries",
624 "server_label",
625 "server_url",
626 "vector_store_ids",
627 ] {
628 if let Some(value) = item.extra.get(key) {
629 input.insert(key.to_string(), value.clone());
630 }
631 }
632
633 let mut output = item.extra.clone();
634 output.remove("action");
635 output.remove("query");
636 output.remove("queries");
637 output.remove("server_label");
638 output.remove("server_url");
639 output.remove("vector_store_ids");
640 if !item.content.is_empty() {
641 output.insert(
642 "content".to_string(),
643 serde_json::to_value(&item.content).unwrap_or(Value::Null),
644 );
645 }
646
647 let mut citations = Vec::new();
648 for key in ["citations", "sources", "results"] {
649 if let Some(value) = item.extra.get(key) {
650 citations.push(value.clone());
651 }
652 }
653 for content in &item.content {
654 citations.extend(content.annotations.iter().cloned());
655 }
656
657 Some(ProviderToolTrace {
658 id,
659 name: name.to_string(),
660 provider: "xai".to_string(),
661 input: Value::Object(input),
662 output: (!output.is_empty()).then_some(Value::Object(output)),
663 citations,
664 })
665}
666
667fn responses_provider_tool_traces(response: &ResponsesResponse) -> Vec<ProviderToolTrace> {
668 response
669 .output
670 .iter()
671 .filter_map(provider_tool_trace_from_responses_output)
672 .collect()
673}
674
675#[derive(Default)]
676struct ResponsesStreamState {
677 text: String,
678 output: HashMap<String, ResponsesOutput>,
679 final_response: Option<ResponsesResponse>,
680 started_provider_tools: HashSet<String>,
681 completed_provider_tools: HashSet<String>,
682}
683
684impl ResponsesStreamState {
685 fn into_response(self) -> ResponsesResponse {
686 self.final_response.unwrap_or_else(|| ResponsesResponse {
687 output: self.output.into_values().collect(),
688 output_text: (!self.text.is_empty()).then_some(self.text),
689 usage: None,
690 })
691 }
692}
693
694fn stream_event_type(value: &Value) -> Option<&str> {
695 value.get("type").and_then(Value::as_str)
696}
697
698fn stream_text_delta(value: &Value) -> Option<&str> {
699 let kind = stream_event_type(value).unwrap_or_default();
700 if kind.contains("output_text.delta") || kind.contains("text.delta") {
701 return value.get("delta").and_then(Value::as_str);
702 }
703 None
704}
705
706fn stream_response(value: &Value) -> Option<ResponsesResponse> {
707 let kind = stream_event_type(value).unwrap_or_default();
708 if !(kind.ends_with(".completed") || kind == "response.completed") {
709 return None;
710 }
711 value
712 .get("response")
713 .cloned()
714 .and_then(|response| serde_json::from_value(response).ok())
715}
716
717fn stream_output_item(value: &Value) -> Option<ResponsesOutput> {
718 for key in ["item", "output_item", "response_item"] {
719 if let Some(item) = value.get(key)
720 && let Ok(output) = serde_json::from_value::<ResponsesOutput>(item.clone())
721 {
722 return Some(output);
723 }
724 }
725 serde_json::from_value::<ResponsesOutput>(value.clone()).ok()
726}
727
728fn stream_tool_phase(value: &Value, output: &ResponsesOutput) -> Option<&'static str> {
729 let kind = stream_event_type(value).unwrap_or_default();
730 if kind.contains(".added") || kind.contains(".in_progress") || kind.contains(".started") {
731 return Some("started");
732 }
733 if kind.contains(".done") || kind.contains(".completed") {
734 return Some("completed");
735 }
736 match output.status.as_deref() {
737 Some("in_progress" | "running" | "searching" | "started") => Some("started"),
738 Some("completed" | "done") => Some("completed"),
739 _ => None,
740 }
741}
742
/// Finds which native tool item type (e.g. `web_search_call`) a stream event
/// type mentions, such as `response.web_search_call.in_progress`.
fn native_kind_from_stream_type(kind: &str) -> Option<&'static str> {
    const NATIVE_ITEM_TYPES: [&str; 5] = [
        "web_search_call",
        "x_search_call",
        "code_interpreter_call",
        "file_search_call",
        "mcp_call",
    ];

    NATIVE_ITEM_TYPES
        .iter()
        .copied()
        .find(|item_type| kind.contains(item_type))
}
754
755fn stream_raw_provider_tool_trace(value: &Value) -> Option<ProviderToolTrace> {
756 let kind = stream_event_type(value)?;
757 let response_item_type = native_kind_from_stream_type(kind)?;
758 let name = xai_native_tool_name(response_item_type)?;
759 let id = value
760 .get("call_id")
761 .or_else(|| value.get("item_id"))
762 .or_else(|| value.get("id"))
763 .and_then(Value::as_str)
764 .map(ToString::to_string)
765 .unwrap_or_else(|| uuid::Uuid::new_v4().to_string());
766
767 let mut input = serde_json::Map::new();
768 input.insert(
769 "response_item_type".to_string(),
770 Value::String(response_item_type.to_string()),
771 );
772 input.insert(
773 "stream_event_type".to_string(),
774 Value::String(kind.to_string()),
775 );
776 if let Some(status) = value.get("status").and_then(Value::as_str) {
777 input.insert("status".to_string(), Value::String(status.to_string()));
778 }
779 for key in ["action", "query", "queries", "server_label", "server_url"] {
780 if let Some(field) = value.get(key) {
781 input.insert(key.to_string(), field.clone());
782 }
783 }
784
785 Some(ProviderToolTrace {
786 id,
787 name: name.to_string(),
788 provider: "xai".to_string(),
789 input: Value::Object(input),
790 output: None,
791 citations: Vec::new(),
792 })
793}
794
795fn stream_raw_provider_tool_phase(value: &Value) -> Option<&'static str> {
796 let kind = stream_event_type(value)?;
797 native_kind_from_stream_type(kind)?;
798 if kind.contains(".done") || kind.contains(".completed") {
799 Some("completed")
800 } else {
801 Some("started")
802 }
803}
804
805fn handle_responses_stream_value(
806 value: Value,
807 state: &mut ResponsesStreamState,
808 events: &tokio::sync::mpsc::UnboundedSender<ProviderStreamEvent>,
809) {
810 if let Some(delta) = stream_text_delta(&value)
811 && !delta.is_empty()
812 {
813 state.text.push_str(delta);
814 let _ = events.send(ProviderStreamEvent::TextDelta(delta.to_string()));
815 }
816
817 if let Some(response) = stream_response(&value) {
818 state.final_response = Some(response);
819 }
820
821 if let Some(output) = stream_output_item(&value)
822 && let Some(trace) = provider_tool_trace_from_responses_output(&output)
823 {
824 let phase = stream_tool_phase(&value, &output);
825 state.output.insert(trace.id.clone(), output);
826 match phase {
827 Some("started") => {
828 if state.started_provider_tools.insert(trace.id.clone()) {
829 let _ = events.send(ProviderStreamEvent::ProviderToolStarted(trace));
830 }
831 }
832 Some("completed") => {
833 if state.started_provider_tools.insert(trace.id.clone()) {
834 let _ = events.send(ProviderStreamEvent::ProviderToolStarted(trace.clone()));
835 }
836 if state.completed_provider_tools.insert(trace.id.clone()) {
837 let _ = events.send(ProviderStreamEvent::ProviderToolCompleted(trace));
838 }
839 }
840 _ => {}
841 }
842 return;
843 }
844
845 if let Some(trace) = stream_raw_provider_tool_trace(&value) {
846 match stream_raw_provider_tool_phase(&value) {
847 Some("completed") => {
848 if state.started_provider_tools.insert(trace.id.clone()) {
849 let _ = events.send(ProviderStreamEvent::ProviderToolStarted(trace.clone()));
850 }
851 if state.completed_provider_tools.insert(trace.id.clone()) {
852 let _ = events.send(ProviderStreamEvent::ProviderToolCompleted(trace));
853 }
854 }
855 Some("started") => {
856 if state.started_provider_tools.insert(trace.id.clone()) {
857 let _ = events.send(ProviderStreamEvent::ProviderToolStarted(trace));
858 }
859 }
860 _ => {}
861 }
862 }
863}
864
865impl XAiProvider {
866 pub fn new(api_key: Option<&str>) -> Self {
867 Self::with_base_url(api_key, XAI_DEFAULT_BASE_URL)
868 }
869
870 pub fn with_base_url(api_key: Option<&str>, base_url: &str) -> Self {
871 let normalized_base_url = base_url.trim_end_matches('/').to_string();
872 Self {
873 api_key: api_key.map(ToString::to_string),
874 base_url: normalized_base_url.clone(),
875 chat: OpenAiCompatibleProvider::new(
876 "xai",
877 &normalized_base_url,
878 api_key,
879 AuthStyle::Bearer,
880 ),
881 client: Client::builder()
882 .timeout(std::time::Duration::from_secs(120))
883 .connect_timeout(std::time::Duration::from_secs(10))
884 .build()
885 .unwrap_or_else(|_| Client::new()),
886 }
887 }
888
889 fn endpoint(&self, path: &str) -> String {
890 format!("{}/{}", self.base_url, path.trim_start_matches('/'))
891 }
892
893 fn api_key(&self) -> anyhow::Result<&str> {
894 self.api_key.as_deref().ok_or_else(|| {
895 anyhow::anyhow!("xAI API key not set. Set XAI_API_KEY or edit config.toml.")
896 })
897 }
898
899 async fn chat_with_native_model_tools(
900 &self,
901 request: ChatRequest<'_>,
902 model: &str,
903 temperature: f64,
904 native_tools: &[NativeModelToolId],
905 ) -> anyhow::Result<ChatResponse> {
906 let api_key = self.api_key()?;
907 let body = ResponsesRequest {
908 model: model.to_string(),
909 input: responses_input(request.messages),
910 tools: native_responses_tools(native_tools, request.tools)?,
911 temperature,
912 stream: false,
913 };
914
915 let response = self
916 .client
917 .post(self.endpoint("/responses"))
918 .header("Authorization", format!("Bearer {api_key}"))
919 .json(&body)
920 .send()
921 .await?;
922
923 if !response.status().is_success() {
924 return Err(crate::api_error("xAI", response).await);
925 }
926
927 let body_text = response.text().await?;
928 let response: ResponsesResponse = serde_json::from_str(&body_text).map_err(|error| {
929 anyhow::anyhow!(
930 "xAI Responses API decode error: {error}\nBody: {}",
931 &body_text[..body_text.len().min(500)]
932 )
933 })?;
934
935 let usage = response
936 .usage
937 .as_ref()
938 .map(|usage| TokenUsage {
939 input_tokens: usage.input_tokens,
940 output_tokens: usage.output_tokens,
941 })
942 .unwrap_or_default();
943 let text = responses_text(&response);
944 let tool_calls = responses_tool_calls(&response);
945 let provider_tool_calls = responses_provider_tool_traces(&response);
946
947 Ok(ChatResponse {
948 text,
949 tool_calls,
950 provider_tool_calls,
951 usage,
952 })
953 }
954
955 async fn chat_with_native_model_tools_streaming(
956 &self,
957 request: ChatRequest<'_>,
958 model: &str,
959 temperature: f64,
960 native_tools: &[NativeModelToolId],
961 events: tokio::sync::mpsc::UnboundedSender<ProviderStreamEvent>,
962 ) -> anyhow::Result<ChatResponse> {
963 let api_key = self.api_key()?;
964 let body = ResponsesRequest {
965 model: model.to_string(),
966 input: responses_input(request.messages),
967 tools: native_responses_tools(native_tools, request.tools)?,
968 temperature,
969 stream: true,
970 };
971 let response = self
972 .client
973 .post(self.endpoint("/responses"))
974 .header("Authorization", format!("Bearer {api_key}"))
975 .json(&body)
976 .send()
977 .await?;
978
979 if !response.status().is_success() {
980 return Err(crate::api_error("xAI", response).await);
981 }
982
983 let mut state = ResponsesStreamState::default();
984 let mut stream = response.bytes_stream();
985 let mut buffer = String::new();
986
987 while let Some(chunk) = stream.next().await {
988 let chunk = chunk?;
989 buffer.push_str(&String::from_utf8_lossy(&chunk));
990 if buffer.contains("\r\n") {
991 buffer = buffer.replace("\r\n", "\n");
992 }
993
994 while let Some(split_at) = buffer.find("\n\n") {
995 let frame = buffer[..split_at].to_string();
996 buffer = buffer[split_at + 2..].to_string();
997
998 for line in frame.lines() {
999 let Some(data) = line.strip_prefix("data:") else {
1000 continue;
1001 };
1002 let data = data.trim();
1003 if data.is_empty() || data == "[DONE]" {
1004 continue;
1005 }
1006 if let Ok(value) = serde_json::from_str::<Value>(data) {
1007 handle_responses_stream_value(value, &mut state, &events);
1008 }
1009 }
1010 }
1011 }
1012
1013 if !buffer.trim().is_empty() {
1014 for line in buffer.lines() {
1015 let Some(data) = line.strip_prefix("data:") else {
1016 continue;
1017 };
1018 let data = data.trim();
1019 if data.is_empty() || data == "[DONE]" {
1020 continue;
1021 }
1022 if let Ok(value) = serde_json::from_str::<Value>(data) {
1023 handle_responses_stream_value(value, &mut state, &events);
1024 }
1025 }
1026 }
1027
1028 let response = state.into_response();
1029 let usage = response
1030 .usage
1031 .as_ref()
1032 .map(|usage| TokenUsage {
1033 input_tokens: usage.input_tokens,
1034 output_tokens: usage.output_tokens,
1035 })
1036 .unwrap_or_default();
1037 let text = responses_text(&response);
1038 let tool_calls = responses_tool_calls(&response);
1039 let provider_tool_calls = responses_provider_tool_traces(&response);
1040
1041 Ok(ChatResponse {
1042 text,
1043 tool_calls,
1044 provider_tool_calls,
1045 usage,
1046 })
1047 }
1048
1049 async fn generate_image(
1050 &self,
1051 request: crate::native::GenerateImageRequest,
1052 ) -> anyhow::Result<NativeMediaResponse> {
1053 let api_key = self.api_key()?;
1054
1055 let response_format = match request.output_format {
1056 MediaOutputFormat::Url => None,
1057 MediaOutputFormat::Base64 => Some("b64_json"),
1058 };
1059 let body = ImageGenerationRequest {
1060 model: &request.model,
1061 prompt: &request.prompt,
1062 n: request.n,
1063 response_format,
1064 aspect_ratio: request.aspect_ratio.as_deref(),
1065 resolution: request.resolution.as_deref(),
1066 };
1067
1068 let response = self
1069 .client
1070 .post(self.endpoint("/images/generations"))
1071 .header("Authorization", format!("Bearer {api_key}"))
1072 .json(&body)
1073 .send()
1074 .await?;
1075
1076 if !response.status().is_success() {
1077 return Err(crate::api_error("xAI", response).await);
1078 }
1079
1080 let images: ImageGenerationResponse = response.json().await?;
1081 let mut assets = Vec::new();
1082 let mut revised_prompts = Vec::new();
1083
1084 for image in images.data {
1085 if let Some(prompt) = image.revised_prompt {
1086 revised_prompts.push(prompt);
1087 }
1088 if let Some(url) = image.url {
1089 assets.push(MediaOutputAsset::Url {
1090 url,
1091 mime_type: Some("image/jpeg".to_string()),
1092 });
1093 } else if let Some(data) = image.b64_json {
1094 assets.push(MediaOutputAsset::Base64 {
1095 data,
1096 mime_type: Some("image/jpeg".to_string()),
1097 });
1098 }
1099 }
1100
1101 if assets.is_empty() {
1102 anyhow::bail!("xAI image generation returned no assets");
1103 }
1104
1105 let metadata = if revised_prompts.is_empty() {
1106 None
1107 } else {
1108 Some(serde_json::json!({ "revised_prompts": revised_prompts }))
1109 };
1110
1111 Ok(NativeMediaResponse::Assets { assets, metadata })
1112 }
1113
1114 async fn edit_image(&self, request: EditImageRequest) -> anyhow::Result<NativeMediaResponse> {
1115 let api_key = self.api_key()?;
1116 let response_format = match request.output_format {
1117 MediaOutputFormat::Url => None,
1118 MediaOutputFormat::Base64 => Some("b64_json"),
1119 };
1120 let body = ImageEditRequest {
1121 model: &request.model,
1122 prompt: &request.prompt,
1123 image: xai_media_input(request.image, true),
1124 response_format,
1125 aspect_ratio: request.aspect_ratio.as_deref(),
1126 resolution: request.resolution.as_deref(),
1127 };
1128
1129 let response = self
1130 .client
1131 .post(self.endpoint("/images/edits"))
1132 .header("Authorization", format!("Bearer {api_key}"))
1133 .json(&body)
1134 .send()
1135 .await?;
1136
1137 if !response.status().is_success() {
1138 return Err(crate::api_error("xAI", response).await);
1139 }
1140
1141 self.parse_image_response(response).await
1142 }
1143
1144 async fn parse_image_response(
1145 &self,
1146 response: reqwest::Response,
1147 ) -> anyhow::Result<NativeMediaResponse> {
1148 let images: ImageGenerationResponse = response.json().await?;
1149 let mut assets = Vec::new();
1150 let mut revised_prompts = Vec::new();
1151
1152 for image in images.data {
1153 if let Some(prompt) = image.revised_prompt {
1154 revised_prompts.push(prompt);
1155 }
1156 if let Some(url) = image.url {
1157 assets.push(MediaOutputAsset::Url {
1158 url,
1159 mime_type: Some("image/jpeg".to_string()),
1160 });
1161 } else if let Some(data) = image.b64_json {
1162 assets.push(MediaOutputAsset::Base64 {
1163 data,
1164 mime_type: Some("image/jpeg".to_string()),
1165 });
1166 }
1167 }
1168
1169 if assets.is_empty() {
1170 anyhow::bail!("xAI image operation returned no assets");
1171 }
1172
1173 let metadata = if revised_prompts.is_empty() {
1174 None
1175 } else {
1176 Some(json!({ "revised_prompts": revised_prompts }))
1177 };
1178
1179 Ok(NativeMediaResponse::Assets { assets, metadata })
1180 }
1181
1182 async fn start_video_job<T: Serialize + ?Sized>(
1183 &self,
1184 path: &str,
1185 operation: NativeOperation,
1186 model: &str,
1187 body: &T,
1188 ) -> anyhow::Result<NativeMediaResponse> {
1189 let api_key = self.api_key()?;
1190 let response = self
1191 .client
1192 .post(self.endpoint(path))
1193 .header("Authorization", format!("Bearer {api_key}"))
1194 .json(body)
1195 .send()
1196 .await?;
1197
1198 if !response.status().is_success() {
1199 return Err(crate::api_error("xAI", response).await);
1200 }
1201
1202 let started: VideoStartResponse = response.json().await?;
1203 Ok(NativeMediaResponse::Job {
1204 job: NativeMediaJob {
1205 provider: "xai".to_string(),
1206 operation,
1207 job_id: started.request_id,
1208 status: NativeMediaJobStatus::Queued,
1209 model: Some(model.to_string()),
1210 metadata: None,
1211 },
1212 })
1213 }
1214
1215 async fn generate_video(
1216 &self,
1217 request: GenerateVideoRequest,
1218 ) -> anyhow::Result<NativeMediaResponse> {
1219 let body = VideoRequest {
1220 model: &request.model,
1221 prompt: &request.prompt,
1222 duration_seconds: request.duration_seconds,
1223 aspect_ratio: request.aspect_ratio.as_deref(),
1224 resolution: request.resolution.as_deref(),
1225 image: None,
1226 reference_images: None,
1227 video: None,
1228 };
1229 self.start_video_job(
1230 "/videos/generations",
1231 NativeOperation::GenerateVideo,
1232 &request.model,
1233 &body,
1234 )
1235 .await
1236 }
1237
1238 async fn image_to_video(
1239 &self,
1240 request: ImageToVideoRequest,
1241 ) -> anyhow::Result<NativeMediaResponse> {
1242 let body = VideoRequest {
1243 model: &request.model,
1244 prompt: &request.prompt,
1245 duration_seconds: request.duration_seconds,
1246 aspect_ratio: request.aspect_ratio.as_deref(),
1247 resolution: request.resolution.as_deref(),
1248 image: Some(xai_media_input(request.image, false)),
1249 reference_images: None,
1250 video: None,
1251 };
1252 self.start_video_job(
1253 "/videos/generations",
1254 NativeOperation::ImageToVideo,
1255 &request.model,
1256 &body,
1257 )
1258 .await
1259 }
1260
1261 async fn reference_to_video(
1262 &self,
1263 request: ReferenceToVideoRequest,
1264 ) -> anyhow::Result<NativeMediaResponse> {
1265 let body = VideoRequest {
1266 model: &request.model,
1267 prompt: &request.prompt,
1268 duration_seconds: request.duration_seconds,
1269 aspect_ratio: request.aspect_ratio.as_deref(),
1270 resolution: request.resolution.as_deref(),
1271 image: None,
1272 reference_images: Some(
1273 request
1274 .reference_images
1275 .into_iter()
1276 .map(|asset| xai_media_input(asset, false))
1277 .collect(),
1278 ),
1279 video: None,
1280 };
1281 self.start_video_job(
1282 "/videos/generations",
1283 NativeOperation::ReferenceToVideo,
1284 &request.model,
1285 &body,
1286 )
1287 .await
1288 }
1289
1290 async fn edit_video(&self, request: EditVideoRequest) -> anyhow::Result<NativeMediaResponse> {
1291 let body = VideoRequest {
1292 model: &request.model,
1293 prompt: &request.prompt,
1294 duration_seconds: None,
1295 aspect_ratio: None,
1296 resolution: None,
1297 image: None,
1298 reference_images: None,
1299 video: Some(xai_media_input(request.video, false)),
1300 };
1301 self.start_video_job(
1302 "/videos/edits",
1303 NativeOperation::EditVideo,
1304 &request.model,
1305 &body,
1306 )
1307 .await
1308 }
1309
1310 async fn extend_video(
1311 &self,
1312 request: ExtendVideoRequest,
1313 ) -> anyhow::Result<NativeMediaResponse> {
1314 let body = VideoRequest {
1315 model: &request.model,
1316 prompt: &request.prompt,
1317 duration_seconds: request.duration_seconds,
1318 aspect_ratio: None,
1319 resolution: None,
1320 image: None,
1321 reference_images: None,
1322 video: Some(xai_media_input(request.video, false)),
1323 };
1324 self.start_video_job(
1325 "/videos/extensions",
1326 NativeOperation::ExtendVideo,
1327 &request.model,
1328 &body,
1329 )
1330 .await
1331 }
1332}
1333
1334#[async_trait]
1335impl ModelProvider for XAiProvider {
1336 async fn chat(
1337 &self,
1338 request: ChatRequest<'_>,
1339 model: &str,
1340 temperature: f64,
1341 ) -> anyhow::Result<ChatResponse> {
1342 if let Some(native_tools) = request.native_tools
1343 && !native_tools.is_empty()
1344 {
1345 return self
1346 .chat_with_native_model_tools(request, model, temperature, native_tools)
1347 .await;
1348 }
1349 self.chat.chat(request, model, temperature).await
1350 }
1351
1352 async fn chat_stream(
1353 &self,
1354 request: ChatRequest<'_>,
1355 model: &str,
1356 temperature: f64,
1357 events: tokio::sync::mpsc::UnboundedSender<ProviderStreamEvent>,
1358 ) -> anyhow::Result<ChatResponse> {
1359 if let Some(native_tools) = request.native_tools
1360 && !native_tools.is_empty()
1361 {
1362 return self
1363 .chat_with_native_model_tools_streaming(
1364 request,
1365 model,
1366 temperature,
1367 native_tools,
1368 events,
1369 )
1370 .await;
1371 }
1372 self.chat
1373 .chat_stream(request, model, temperature, events)
1374 .await
1375 }
1376
1377 fn context_window(&self, model: &str) -> Option<usize> {
1378 self.chat.context_window(model)
1379 }
1380
1381 fn supports_native_tools(&self) -> bool {
1382 true
1383 }
1384
1385 fn supports_developer_role(&self, model: &str) -> bool {
1386 self.chat.supports_developer_role(model)
1387 }
1388
1389 fn native_capabilities(&self) -> Option<ProviderNativeCapabilities> {
1390 Some(NativeCapabilitiesProvider::native_capabilities(self))
1391 }
1392
1393 async fn submit_media(
1394 &self,
1395 request: NativeMediaRequest,
1396 ) -> anyhow::Result<NativeMediaResponse> {
1397 NativeCapabilitiesProvider::submit_media(self, request).await
1398 }
1399
1400 async fn poll_media_job(&self, job: &NativeMediaJob) -> anyhow::Result<NativeMediaResponse> {
1401 NativeCapabilitiesProvider::poll_media_job(self, job).await
1402 }
1403
1404 async fn warmup(&self) -> anyhow::Result<()> {
1405 self.chat.warmup().await
1406 }
1407}
1408
1409#[async_trait]
1410impl NativeCapabilitiesProvider for XAiProvider {
1411 fn native_capabilities(&self) -> ProviderNativeCapabilities {
1412 ProviderNativeCapabilities {
1413 provider: "xai".to_string(),
1414 model_tools: xai_native_model_tool_specs(),
1415 models: vec![
1416 ModelNativeCapabilities {
1417 model_pattern: "grok-imagine-image-quality".to_string(),
1418 tools: vec![
1419 xai_image_tool_spec(NativeOperation::GenerateImage),
1420 xai_image_tool_spec(NativeOperation::EditImage),
1421 ],
1422 },
1423 ModelNativeCapabilities {
1424 model_pattern: "grok-imagine-video*".to_string(),
1425 tools: vec![
1426 xai_video_tool_spec(NativeOperation::GenerateVideo),
1427 xai_video_tool_spec(NativeOperation::EditVideo),
1428 xai_video_tool_spec(NativeOperation::ImageToVideo),
1429 xai_video_tool_spec(NativeOperation::ReferenceToVideo),
1430 xai_video_tool_spec(NativeOperation::ExtendVideo),
1431 ],
1432 },
1433 ],
1434 }
1435 }
1436
1437 async fn submit_media(
1438 &self,
1439 request: NativeMediaRequest,
1440 ) -> anyhow::Result<NativeMediaResponse> {
1441 let operation = request.operation();
1442 match request {
1443 NativeMediaRequest::GenerateImage(request) => self.generate_image(request).await,
1444 NativeMediaRequest::EditImage(request) => self.edit_image(request).await,
1445 NativeMediaRequest::GenerateVideo(request) => self.generate_video(request).await,
1446 NativeMediaRequest::EditVideo(request) => self.edit_video(request).await,
1447 NativeMediaRequest::ImageToVideo(request) => self.image_to_video(request).await,
1448 NativeMediaRequest::ReferenceToVideo(request) => self.reference_to_video(request).await,
1449 NativeMediaRequest::ExtendVideo(request) => self.extend_video(request).await,
1450 NativeMediaRequest::GenerateSpeech(_) | NativeMediaRequest::TranscribeAudio(_) => {
1451 anyhow::bail!(
1452 "xAI native operation {operation:?} is declared but not implemented in this pass"
1453 )
1454 }
1455 }
1456 }
1457
1458 async fn poll_media_job(&self, job: &NativeMediaJob) -> anyhow::Result<NativeMediaResponse> {
1459 let api_key = self.api_key()?;
1460 let response = self
1461 .client
1462 .get(self.endpoint(format!("/videos/{}", job.job_id).as_str()))
1463 .header("Authorization", format!("Bearer {api_key}"))
1464 .send()
1465 .await?;
1466
1467 if !response.status().is_success() {
1468 return Err(crate::api_error("xAI", response).await);
1469 }
1470
1471 let polled: VideoPollResponse = response.json().await?;
1472 let status = xai_video_status(&polled.status)?;
1473 if status == NativeMediaJobStatus::Completed {
1474 let video = polled.video.ok_or_else(|| {
1475 anyhow::anyhow!("xAI video job {} completed without a video", job.job_id)
1476 })?;
1477 let metadata = video
1478 .duration
1479 .map(|duration| json!({ "duration_seconds": duration }));
1480 return Ok(NativeMediaResponse::Assets {
1481 assets: vec![MediaOutputAsset::Url {
1482 url: video.url,
1483 mime_type: Some("video/mp4".to_string()),
1484 }],
1485 metadata,
1486 });
1487 }
1488
1489 let metadata = polled
1490 .error
1491 .and_then(|error| serde_json::to_value(error).ok());
1492 Ok(NativeMediaResponse::Job {
1493 job: NativeMediaJob {
1494 provider: job.provider.clone(),
1495 operation: job.operation,
1496 job_id: job.job_id.clone(),
1497 status,
1498 model: job.model.clone(),
1499 metadata,
1500 },
1501 })
1502 }
1503}
1504
// Unit tests for the xAI provider. Every test is synchronous and covers one of:
// construction defaults, the declared native capabilities, video-job status and poll
// parsing, media-input serialization, or the Responses-API tool, output and stream
// helpers.
#[cfg(test)]
mod tests {
    use super::*;

    // `XAiProvider::new` must point at `XAI_DEFAULT_BASE_URL`.
    #[test]
    fn creates_with_default_base_url() {
        let provider = XAiProvider::new(Some("xai-key"));
        assert_eq!(provider.base_url, XAI_DEFAULT_BASE_URL);
    }

    // The `grok-imagine-video*` entry must advertise image-to-video, reference-to-video
    // and extension. No API key is needed to build the capability description.
    #[test]
    fn capabilities_include_xai_video_modes() {
        let provider = XAiProvider::new(None);
        let capabilities = NativeCapabilitiesProvider::native_capabilities(&provider);
        let video = capabilities
            .models
            .iter()
            .find(|model| model.model_pattern == "grok-imagine-video*")
            .expect("video capability");

        assert!(
            video
                .operations()
                .any(|op| op == NativeOperation::ImageToVideo)
        );
        assert!(
            video
                .operations()
                .any(|op| op == NativeOperation::ReferenceToVideo)
        );
        assert!(
            video
                .operations()
                .any(|op| op == NativeOperation::ExtendVideo)
        );
    }

    // Maps xAI REST status strings onto `NativeMediaJobStatus`. Note that `pending`
    // is reported as `Running`, not `Queued`.
    #[test]
    fn xai_video_status_maps_to_native_status() {
        assert_eq!(
            xai_video_status("pending").expect("pending"),
            NativeMediaJobStatus::Running
        );
        assert_eq!(
            xai_video_status("done").expect("done"),
            NativeMediaJobStatus::Completed
        );
        assert_eq!(
            xai_video_status("expired").expect("expired"),
            NativeMediaJobStatus::Expired
        );
        assert_eq!(
            xai_video_status("failed").expect("failed"),
            NativeMediaJobStatus::Failed
        );
    }

    // A completed poll payload must deserialize even with extra fields
    // (`respect_moderation`, top-level `model`). The integer `duration` surfaces as
    // `Some(8.0)`.
    #[test]
    fn xai_video_poll_response_matches_rest_done_shape() {
        let response: VideoPollResponse = serde_json::from_value(json!({
            "status": "done",
            "video": {
                "url": "https://vidgen.x.ai/example/video.mp4",
                "duration": 8,
                "respect_moderation": true
            },
            "model": "grok-imagine-video"
        }))
        .expect("poll response should parse");

        assert_eq!(response.status, "done");
        let video = response.video.expect("video asset");
        assert_eq!(video.url, "https://vidgen.x.ai/example/video.mp4");
        assert_eq!(video.duration, Some(8.0));
    }

    // With the image-edit flag set (`true`), a URL input serializes as
    // `{ "type": "image_url", "url": ... }`.
    #[test]
    fn xai_image_edit_input_uses_image_url_shape() {
        let input = xai_media_input(
            MediaInputAsset::Url {
                url: "https://example.com/image.png".to_string(),
            },
            true,
        );
        let value = serde_json::to_value(input).expect("serialize");

        assert_eq!(value["type"], "image_url");
        assert_eq!(value["url"], "https://example.com/image.png");
    }

    // Native tool ids come first, in the order requested, followed by local function
    // tools (which keep their `name`).
    #[test]
    fn xai_responses_tools_include_native_and_local_tools() {
        let tools = native_responses_tools(
            &[
                NativeModelToolId::from("web_search"),
                NativeModelToolId::from("x_search"),
            ],
            Some(&[crate::ToolSpec {
                name: "shell".to_string(),
                description: "Run a shell command.".to_string(),
                parameters: json!({
                    "type": "object",
                    "properties": {
                        "cmd": { "type": "string" }
                    },
                    "required": ["cmd"]
                }),
                category: crate::ToolCategory::Write,
            }]),
        )
        .expect("supported tools");

        assert_eq!(tools[0].kind, "web_search");
        assert_eq!(tools[1].kind, "x_search");
        assert_eq!(tools[2].kind, "function");
        assert_eq!(tools[2].name.as_deref(), Some("shell"));
    }

    // Unknown native tool ids are rejected, and the error message names the
    // offending id.
    #[test]
    fn xai_responses_tools_reject_unknown_native_tool_ids() {
        let error = native_responses_tools(&[NativeModelToolId::from("unknown_tool")], None)
            .expect_err("unsupported tool should fail");

        assert!(error.to_string().contains("unknown_tool"));
    }

    // `function_call` output items become `ToolCall`s. The id is taken from
    // `call_id`, and the arguments are passed through as the raw JSON string.
    #[test]
    fn xai_responses_extracts_function_calls() {
        let response: ResponsesResponse = serde_json::from_value(json!({
            "output": [
                {
                    "type": "function_call",
                    "call_id": "call_123",
                    "name": "shell",
                    "arguments": "{\"cmd\":\"date\"}"
                }
            ],
            "usage": {
                "input_tokens": 5,
                "output_tokens": 3
            }
        }))
        .expect("responses payload should parse");

        let calls = responses_tool_calls(&response);
        assert_eq!(calls.len(), 1);
        assert_eq!(calls[0].id, "call_123");
        assert_eq!(calls[0].name, "shell");
        assert_eq!(calls[0].arguments, "{\"cmd\":\"date\"}");
    }

    // A `web_search_call` item yields one provider trace named `web_search`, with
    // status in `input`, results in `output`, and exactly one citation. The search
    // result and the `url_citation` share a URL; presumably they are deduplicated,
    // but confirm. The `message` item yields no trace.
    #[test]
    fn xai_responses_extracts_provider_native_tool_traces() {
        let response: ResponsesResponse = serde_json::from_value(json!({
            "output": [
                {
                    "id": "ws_123",
                    "type": "web_search_call",
                    "status": "completed",
                    "action": {
                        "type": "search",
                        "query": "latest xAI models"
                    },
                    "results": [
                        { "title": "xAI Docs", "url": "https://docs.x.ai/developers/models" }
                    ]
                },
                {
                    "type": "message",
                    "content": [
                        {
                            "type": "output_text",
                            "text": "xAI has new models.",
                            "annotations": [
                                { "type": "url_citation", "url": "https://docs.x.ai/developers/models" }
                            ]
                        }
                    ]
                }
            ]
        }))
        .expect("responses payload should parse");

        let traces = responses_provider_tool_traces(&response);
        assert_eq!(traces.len(), 1);
        assert_eq!(traces[0].id, "ws_123");
        assert_eq!(traces[0].name, "web_search");
        assert_eq!(traces[0].provider, "xai");
        assert_eq!(traces[0].input["status"], "completed");
        assert!(traces[0].output.is_some());
        assert_eq!(traces[0].citations.len(), 1);
    }

    // `output_item.added` then `output_item.done` for the same item must emit
    // `ProviderToolStarted` followed by `ProviderToolCompleted`. The completion event
    // carries citations derived from `results`.
    #[test]
    fn xai_stream_parser_emits_provider_tool_start_and_completion() {
        let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel();
        let mut state = ResponsesStreamState::default();

        handle_responses_stream_value(
            json!({
                "type": "response.output_item.added",
                "item": {
                    "id": "ws_123",
                    "type": "web_search_call",
                    "status": "in_progress",
                    "action": {
                        "type": "search",
                        "query": "latest xAI models"
                    }
                }
            }),
            &mut state,
            &tx,
        );
        handle_responses_stream_value(
            json!({
                "type": "response.output_item.done",
                "item": {
                    "id": "ws_123",
                    "type": "web_search_call",
                    "status": "completed",
                    "results": [
                        { "title": "xAI Docs", "url": "https://docs.x.ai/developers/models" }
                    ]
                }
            }),
            &mut state,
            &tx,
        );

        // Events must arrive in emission order: start first, then completion.
        match rx.try_recv().expect("start event") {
            ProviderStreamEvent::ProviderToolStarted(trace) => {
                assert_eq!(trace.id, "ws_123");
                assert_eq!(trace.name, "web_search");
            }
            other => panic!("unexpected event: {other:?}"),
        }
        match rx.try_recv().expect("completion event") {
            ProviderStreamEvent::ProviderToolCompleted(trace) => {
                assert_eq!(trace.id, "ws_123");
                assert_eq!(trace.name, "web_search");
                assert!(!trace.citations.is_empty());
            }
            other => panic!("unexpected event: {other:?}"),
        }
    }

    // Raw `response.web_search_call.in_progress` events (keyed by `item_id`, with a
    // top-level `query`) also produce a start event, and the query is copied into
    // `input`.
    #[test]
    fn xai_stream_parser_tolerates_raw_provider_tool_events() {
        let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel();
        let mut state = ResponsesStreamState::default();

        handle_responses_stream_value(
            json!({
                "type": "response.web_search_call.in_progress",
                "item_id": "ws_raw_123",
                "query": "current events"
            }),
            &mut state,
            &tx,
        );

        match rx.try_recv().expect("start event") {
            ProviderStreamEvent::ProviderToolStarted(trace) => {
                assert_eq!(trace.id, "ws_raw_123");
                assert_eq!(trace.name, "web_search");
                assert_eq!(trace.input["query"], "current events");
            }
            other => panic!("unexpected event: {other:?}"),
        }
    }
}