use llmsdk_provider::ImageModel;
use llmsdk_provider::image_model::ImageOptions;
use llmsdk_provider::language_model::FilePart;
use llmsdk_provider::shared::{FileBytes, FileData, ProviderOptions};
use llmsdk_xai::Xai;
use serde_json::{Value, json};
use wiremock::matchers::{body_partial_json, header, method, path};
use wiremock::{Mock, MockServer, ResponseTemplate};
const RED_PIXEL_PNG_B64: &str = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mP8/5+hHgAHggJ/PchI7wAAAABJRU5ErkJggg==";
fn provider(server: &MockServer) -> Xai {
Xai::builder()
.api_key("xai-test")
.base_url(server.uri())
.build()
.expect("provider builds")
}
fn provider_options_with_xai(value: &Value) -> ProviderOptions {
let mut po = ProviderOptions::new();
po.insert("xai".into(), value.as_object().cloned().unwrap());
po
}
fn ok_b64_response() -> Value {
json!({
"data": [{
"b64_json": RED_PIXEL_PNG_B64,
"revised_prompt": "A serene red square."
}],
"usage": { "cost_in_usd_ticks": 1234_u64 }
})
}
#[tokio::test]
async fn generations_happy_path_decodes_png_and_collects_metadata() {
let server = MockServer::start().await;
Mock::given(method("POST"))
.and(path("/images/generations"))
.and(header("authorization", "Bearer xai-test"))
.and(body_partial_json(json!({
"model": "grok-imagine-image",
"prompt": "a red square",
"response_format": "b64_json"
})))
.respond_with(ResponseTemplate::new(200).set_body_json(ok_b64_response()))
.mount(&server)
.await;
let model = provider(&server).image("grok-imagine-image");
let result = model
.do_generate(ImageOptions {
prompt: "a red square".into(),
..Default::default()
})
.await
.expect("ok");
assert_eq!(result.images.len(), 1);
let img = &result.images[0];
assert_eq!(&img.bytes[..8], b"\x89PNG\r\n\x1a\n");
assert_eq!(img.media_type, "image/png");
let pm = result.provider_metadata.expect("provider_metadata");
let xai = pm.get("xai").expect("xai entry");
let images = xai
.get("images")
.and_then(|v| v.as_array())
.expect("images array");
assert_eq!(images.len(), 1);
assert_eq!(images[0]["revisedPrompt"], "A serene red square.");
assert_eq!(xai["costInUsdTicks"], 1234);
let resp = result.response.expect("response info");
assert_eq!(resp.model_id.as_deref(), Some("grok-imagine-image"));
assert!(result.warnings.is_empty(), "no warnings expected");
}
#[tokio::test]
async fn provider_options_are_forwarded_to_wire() {
let server = MockServer::start().await;
Mock::given(method("POST"))
.and(path("/images/generations"))
.and(body_partial_json(json!({
"model": "grok-imagine-image",
"prompt": "a cat",
"response_format": "b64_json",
"aspect_ratio": "16:9",
"output_format": "png",
"sync_mode": true,
"resolution": "2k",
"quality": "high",
"user": "alice@example.com"
})))
.respond_with(ResponseTemplate::new(200).set_body_json(ok_b64_response()))
.mount(&server)
.await;
let model = provider(&server).image("grok-imagine-image");
let opts = ImageOptions {
prompt: "a cat".into(),
provider_options: Some(provider_options_with_xai(&json!({
"aspect_ratio": "16:9",
"output_format": "png",
"sync_mode": true,
"resolution": "2k",
"quality": "high",
"user": "alice@example.com"
}))),
..Default::default()
};
let result = model.do_generate(opts).await.expect("ok");
assert_eq!(result.images.len(), 1);
}
#[tokio::test]
async fn size_seed_mask_emit_unsupported_warnings_but_call_still_succeeds() {
let server = MockServer::start().await;
Mock::given(method("POST"))
.and(path("/images/generations"))
.respond_with(ResponseTemplate::new(200).set_body_json(ok_b64_response()))
.mount(&server)
.await;
let model = provider(&server).image("grok-imagine-image");
let opts = ImageOptions {
prompt: "hi".into(),
size: Some("1024x1024".into()),
seed: Some(7),
mask: Some(FilePart {
filename: None,
data: FileData::Data {
data: FileBytes::Bytes(vec![0xFF]),
},
media_type: "image/png".into(),
provider_options: None,
}),
..Default::default()
};
let result = model.do_generate(opts).await.expect("ok");
let settings: Vec<&str> = result
.warnings
.iter()
.filter_map(|w| match w {
llmsdk_provider::shared::Warning::Unsupported { feature, .. } => Some(feature.as_str()),
_ => None,
})
.collect();
assert!(settings.contains(&"size"));
assert!(settings.contains(&"seed"));
assert!(settings.contains(&"mask"));
}
#[tokio::test]
async fn files_route_to_edits_endpoint_with_single_image_field() {
let server = MockServer::start().await;
Mock::given(method("POST"))
.and(path("/images/edits"))
.and(body_partial_json(json!({
"model": "grok-imagine-image",
"prompt": "make it blue",
"response_format": "b64_json",
"image": { "url": "data:image/png;base64,Zm9v", "type": "image_url" }
})))
.respond_with(ResponseTemplate::new(200).set_body_json(ok_b64_response()))
.mount(&server)
.await;
let model = provider(&server).image("grok-imagine-image");
let opts = ImageOptions {
prompt: "make it blue".into(),
files: Some(vec![FilePart {
filename: None,
data: FileData::Data {
data: FileBytes::Bytes(b"foo".to_vec()),
},
media_type: "image/png".into(),
provider_options: None,
}]),
..Default::default()
};
let result = model.do_generate(opts).await.expect("ok");
assert_eq!(result.images.len(), 1);
}
#[tokio::test]
async fn url_only_response_triggers_download_fallback() {
let server = MockServer::start().await;
let download_path = "/cdn/red.png";
let download_url = format!("{}{download_path}", server.uri());
Mock::given(method("POST"))
.and(path("/images/generations"))
.respond_with(ResponseTemplate::new(200).set_body_json(json!({
"data": [{ "url": download_url }]
})))
.mount(&server)
.await;
let png_bytes: Vec<u8> = decode_b64_for_test(RED_PIXEL_PNG_B64);
Mock::given(method("GET"))
.and(path(download_path))
.respond_with(ResponseTemplate::new(200).set_body_raw(png_bytes.clone(), "image/png"))
.mount(&server)
.await;
let model = provider(&server).image("grok-imagine-image");
let result = model
.do_generate(ImageOptions {
prompt: "url me".into(),
..Default::default()
})
.await
.expect("ok");
assert_eq!(result.images.len(), 1);
let img = &result.images[0];
assert_eq!(img.bytes.as_ref(), png_bytes.as_slice());
assert_eq!(img.media_type, "image/png");
}
#[tokio::test]
async fn max_images_per_call_reports_three() {
let server = MockServer::start().await;
let model = provider(&server).image("grok-imagine-image-pro");
assert_eq!(model.max_images_per_call().await, Some(3));
assert_eq!(model.provider(), "xai");
assert_eq!(model.model_id(), "grok-imagine-image-pro");
}
fn decode_b64_for_test(input: &str) -> Vec<u8> {
fn dec(b: u8) -> Option<(u8, bool)> {
Some(match b {
b'A'..=b'Z' => (b - b'A', false),
b'a'..=b'z' => (b - b'a' + 26, false),
b'0'..=b'9' => (b - b'0' + 52, false),
b'+' => (62, false),
b'/' => (63, false),
b'=' => (0, true),
_ => return None,
})
}
let bytes = input.as_bytes();
assert!(bytes.len().is_multiple_of(4), "test base64 padded");
let mut out = Vec::with_capacity(bytes.len() / 4 * 3);
for chunk in bytes.chunks_exact(4) {
let (b0, _) = dec(chunk[0]).expect("alphabet");
let (b1, _) = dec(chunk[1]).expect("alphabet");
let (b2, p2) = dec(chunk[2]).expect("alphabet");
let (b3, p3) = dec(chunk[3]).expect("alphabet");
let n =
(u32::from(b0) << 18) | (u32::from(b1) << 12) | (u32::from(b2) << 6) | u32::from(b3);
out.push(((n >> 16) & 0xFF) as u8);
if !p2 {
out.push(((n >> 8) & 0xFF) as u8);
}
if !p3 {
out.push((n & 0xFF) as u8);
}
}
out
}