use super::{runtime::ContentPart, ProviderError, ProviderKind};
use base64::{engine::general_purpose::STANDARD as BASE64_STANDARD, Engine as _};
use reqwest::ClientBuilder;
use serde::{Deserialize, Serialize};
use std::time::Duration;
/// Image dimensions sent as the `size` field of every image-generation request
/// built by `send_image_request`.
const DEFAULT_IMAGE_SIZE: &str = "1024x1024";
/// JSON body that `send_request` posts to the chat completions URL.
#[derive(Serialize, Debug)]
struct Request {
/// Model name, copied from the caller's `model` argument.
model: String,
/// Conversation messages; `send_request` always sends exactly one `user` message.
messages: Vec<RequestMessage>,
/// Sampling temperature; `send_request` fills it from `super::DEFAULT_TEMPERATURE`.
temperature: f64,
/// Structured-output descriptor, already passed through `normalize_response_format`.
response_format: serde_json::Value,
}
/// One message inside the `messages` array of a chat completion `Request`.
#[derive(Serialize, Debug)]
struct RequestMessage {
/// Author role of the message; `send_request` sets this to `"user"`.
role: String,
/// Ordered content parts (text, image, file) that make up the message.
content: Vec<RequestContentPart>,
}
/// A single piece of message content.
///
/// Serialized as an internally tagged object: the variant name becomes a
/// snake_case `type` field (`"text"`, `"image_url"`, `"file"`) next to the
/// variant's own fields.
#[derive(Serialize, Debug)]
#[serde(tag = "type", rename_all = "snake_case")]
enum RequestContentPart {
/// Plain text content.
Text { text: String },
/// Image content referenced through a URL string.
ImageUrl { image_url: ImageUrl },
/// File attachment described by a name and its data string.
File { file: FileInput },
}
/// Payload of a `RequestContentPart::ImageUrl` part.
#[derive(Serialize, Debug)]
struct ImageUrl {
/// Image location; `request_content_parts` fills it from `ContentPart::Image::data_url`.
url: String,
}
/// Payload of a `RequestContentPart::File` part.
#[derive(Serialize, Debug)]
struct FileInput {
/// Name of the attached file, copied from `ContentPart::File::filename`.
filename: String,
/// File contents as a string, copied verbatim from `ContentPart::File::file_data`.
file_data: String,
}
/// Message object found inside each chat completion choice.
#[derive(Deserialize, Debug)]
struct ResponseMessage {
// Must be present for deserialization to succeed, but is never read by this
// module, hence the lint suppression.
#[allow(dead_code)]
role: String,
/// Reply text; this is the value `send_request` returns to its caller.
content: String,
}
/// One entry of the `choices` array in a chat completion response.
#[derive(Deserialize, Debug)]
struct Choice {
/// The message produced for this choice.
message: ResponseMessage,
}
/// Top-level chat completion response body, reduced to the fields this module reads.
#[derive(Deserialize, Debug)]
struct Response {
/// Returned choices; `send_request` uses only the first one.
choices: Vec<Choice>,
}
/// JSON body that `send_image_request` posts to the image-generation endpoint.
#[derive(Serialize, Debug)]
struct ImageGenerationRequest {
/// Model name, copied from the caller's `model` argument.
model: String,
/// Text prompt describing the image to generate.
prompt: String,
/// Number of images requested; `send_image_request` always asks for 1.
n: u8,
/// Requested image dimensions; filled from `DEFAULT_IMAGE_SIZE`.
size: String,
/// Requested encoding of the result; `send_image_request` sends `"b64_json"`.
response_format: String,
}
/// Top-level image-generation response body, reduced to the fields this module reads.
#[derive(Deserialize, Debug)]
struct ImageGenerationResponse {
/// Generated images; `send_image_request` uses only the first entry.
data: Vec<ImageGenerationData>,
}
/// One generated image inside an `ImageGenerationResponse`.
#[derive(Deserialize, Debug)]
struct ImageGenerationData {
/// Image payload; `send_image_request` trims it and decodes it with the
/// standard base64 alphabet.
b64_json: String,
}
/// Ensures the response format is expressed as a `json_schema` envelope.
///
/// A value whose top-level `"type"` is the string `"json_schema"` is returned
/// untouched. Anything else is treated as a bare schema and embedded under
/// `json_schema.schema`, with the name `"Output"` and `strict` set to `true`.
fn normalize_response_format(response_format: serde_json::Value) -> serde_json::Value {
    let already_wrapped = matches!(
        response_format
            .get("type")
            .and_then(|declared| declared.as_str()),
        Some("json_schema")
    );

    if already_wrapped {
        response_format
    } else {
        serde_json::json!({
            "type": "json_schema",
            "json_schema": {
                "name": "Output",
                "schema": response_format,
                "strict": true
            }
        })
    }
}
pub async fn send_request(
url: &String,
model: &String,
content_parts: &[ContentPart],
timeout_in_sec: u64,
response_format: serde_json::Value,
) -> Result<String, ProviderError> {
let request = Request {
model: model.clone(),
messages: vec![RequestMessage {
role: "user".to_string(),
content: request_content_parts(content_parts),
}],
temperature: super::DEFAULT_TEMPERATURE,
response_format: normalize_response_format(response_format),
};
let client = ClientBuilder::new()
.timeout(Duration::from_secs(timeout_in_sec))
.build()
.map_err(|error| ProviderError::from_reqwest(ProviderKind::Ollama, error))?;
let http_resp = client
.post(url)
.json(&request)
.send()
.await
.map_err(|error| ProviderError::from_reqwest(ProviderKind::Ollama, error))?;
let status = http_resp.status();
let body_bytes = http_resp
.bytes()
.await
.map_err(|error| ProviderError::from_reqwest(ProviderKind::Ollama, error))?;
if !status.is_success() {
let raw = String::from_utf8_lossy(&body_bytes);
return Err(ProviderError::from_http_status(
ProviderKind::Ollama,
status,
&raw,
));
}
let reply: Response = match serde_json::from_slice(&body_bytes) {
Ok(resp) => resp,
Err(error) => {
let raw = String::from_utf8_lossy(&body_bytes);
return Err(ProviderError::invalid_response(
ProviderKind::Ollama,
format!("Failed to parse JSON: {error}\nRaw response:\n{raw}"),
));
}
};
match reply.choices.first() {
Some(choice) => Ok(choice.message.content.clone()),
None => Err(ProviderError::invalid_response(
ProviderKind::Ollama,
"Ollama returned no chat completion choices.",
)),
}
}
/// Derives the image-generation endpoint from an arbitrary API URL.
///
/// Trailing slashes are dropped first. Everything from the first `/v1/`
/// segment onwards (or a trailing `/v1`) is then discarded, and
/// `/v1/images/generations` is appended to what remains. A URL without any
/// `/v1` marker simply gets the suffix appended.
fn normalize_images_url(url: &str) -> String {
    let without_slashes = url.trim_end_matches('/');
    let root = match without_slashes.split_once("/v1/") {
        Some((before_version, _)) => before_version,
        None => without_slashes
            .strip_suffix("/v1")
            .unwrap_or(without_slashes),
    };
    format!("{root}/v1/images/generations")
}
/// Requests one generated image and returns its decoded bytes.
///
/// The endpoint is derived from `url` with `normalize_images_url`, so a chat
/// completions URL can be passed in unchanged. The request asks for a single
/// image of `DEFAULT_IMAGE_SIZE`, encoded as `b64_json`, and the whole HTTP
/// exchange is bounded by `timeout_in_sec` seconds.
///
/// `token` is trimmed before use. When the trimmed token is non-empty it is
/// sent as an `Authorization: Bearer …` header; otherwise no authorization
/// header is added.
///
/// # Errors
///
/// Returns a `ProviderError` tagged with `ProviderKind::Ollama` when:
/// - the HTTP client cannot be built, the request cannot be sent, or the body
///   cannot be read;
/// - the server answers with a non-success status (the raw body is forwarded);
/// - the body is not valid JSON for the expected response shape;
/// - `data[0].b64_json` is missing or blank;
/// - the payload is not valid standard base64.
pub async fn send_image_request(
    url: &String,
    model: &String,
    prompt: &str,
    timeout_in_sec: u64,
    token: &str,
) -> Result<Vec<u8>, ProviderError> {
    let request = ImageGenerationRequest {
        model: model.clone(),
        prompt: prompt.to_string(),
        n: 1,
        size: DEFAULT_IMAGE_SIZE.to_string(),
        response_format: "b64_json".to_string(),
    };
    let client = ClientBuilder::new()
        .timeout(Duration::from_secs(timeout_in_sec))
        .build()
        .map_err(|error| ProviderError::from_reqwest(ProviderKind::Ollama, error))?;
    let endpoint = normalize_images_url(url);

    // `.json()` below sets `Content-Type: application/json`, so no explicit
    // header is needed here.
    let mut request_builder = client.post(endpoint.as_str());

    // Use the same trimmed value for the emptiness check and for the header.
    // Sending the untrimmed token would put stray whitespace (for example a
    // trailing newline from a config file) into the header value, which either
    // makes the header invalid or corrupts the credential.
    let token = token.trim();
    if !token.is_empty() {
        request_builder = request_builder.bearer_auth(token);
    }

    let http_resp = request_builder
        .json(&request)
        .send()
        .await
        .map_err(|error| ProviderError::from_reqwest(ProviderKind::Ollama, error))?;

    // Capture the status first: reading the body consumes the response.
    let status = http_resp.status();
    let body_bytes = http_resp
        .bytes()
        .await
        .map_err(|error| ProviderError::from_reqwest(ProviderKind::Ollama, error))?;
    if !status.is_success() {
        let raw = String::from_utf8_lossy(&body_bytes);
        return Err(ProviderError::from_http_status(
            ProviderKind::Ollama,
            status,
            &raw,
        ));
    }

    let response: ImageGenerationResponse =
        serde_json::from_slice(&body_bytes).map_err(|error| {
            let raw = String::from_utf8_lossy(&body_bytes);
            ProviderError::invalid_response(
                ProviderKind::Ollama,
                format!("Failed to parse image-generation JSON: {error}\nRaw response:\n{raw}"),
            )
        })?;

    // Only the first image is used; a blank payload is treated as missing.
    let encoded_image = response
        .data
        .first()
        .map(|image| image.b64_json.trim())
        .filter(|image| !image.is_empty())
        .ok_or_else(|| {
            ProviderError::invalid_response(
                ProviderKind::Ollama,
                "Image generation response did not include `data[0].b64_json`.",
            )
        })?;

    BASE64_STANDARD.decode(encoded_image).map_err(|error| {
        ProviderError::invalid_response(
            ProviderKind::Ollama,
            format!("Failed to decode generated image bytes: {error}"),
        )
    })
}
/// Converts runtime content parts into their wire-format counterparts.
///
/// The output has one entry per input part, in the same order; every string
/// is copied so the result owns its data.
fn request_content_parts(content_parts: &[ContentPart]) -> Vec<RequestContentPart> {
    let mut wire_parts = Vec::with_capacity(content_parts.len());
    for part in content_parts {
        let wire_part = match part {
            ContentPart::Text(body) => RequestContentPart::Text { text: body.clone() },
            ContentPart::Image { data_url } => {
                let image_url = ImageUrl {
                    url: data_url.clone(),
                };
                RequestContentPart::ImageUrl { image_url }
            }
            ContentPart::File {
                filename,
                file_data,
            } => {
                let file = FileInput {
                    filename: filename.clone(),
                    file_data: file_data.clone(),
                };
                RequestContentPart::File { file }
            }
        };
        wire_parts.push(wire_part);
    }
    wire_parts
}
// Unit tests for the schema wrapper and for both request functions. The HTTP
// tests run against a local mockito server, so no real provider is contacted.
#[cfg(test)]
mod tests {
use super::*;
use mockito::{Matcher, Server};
// A bare JSON schema must be wrapped into the `json_schema` envelope with the
// fixed name "Output" and `strict: true`.
#[test]
fn wraps_plain_schema_response_format_for_ollama_chat_completions() {
let schema = serde_json::json!({
"type": "object",
"properties": {
"ok": { "type": "boolean" }
},
"required": ["ok"]
});
let wrapped = normalize_response_format(schema.clone());
assert_eq!(
wrapped,
serde_json::json!({
"type": "json_schema",
"json_schema": {
"name": "Output",
"schema": schema,
"strict": true
}
})
);
}
// Happy path for `send_request`: the mock only answers POSTs that carry a JSON
// content type, and the content of the first choice must be returned.
#[tokio::test]
async fn test_send_request_with_mock() {
let mut server = Server::new_async().await;
let mock_path = "/v1/chat/completions";
let _m = server
.mock("POST", mock_path)
.match_header("content-type", "application/json")
.with_status(200)
.with_body(
r#"{
"choices": [
{
"message": {
"role": "assistant",
"content": "Mocked response"
}
}
]
}"#,
)
.create();
let result = send_request(
&format!("{}{}", server.url(), mock_path),
&"test-model".to_string(),
&[ContentPart::Text("test prompt".to_string())],
5,
// Already in envelope form, so `normalize_response_format` passes it through.
serde_json::json!({
"type": "json_schema",
"json_schema": {
"name": "Output",
"schema": {
"type": "object",
"properties": { "ok": { "type": "boolean" } },
"required": ["ok"]
},
"strict": true
}
}),
)
.await
.expect("send_request failed");
assert_eq!(result, "Mocked response");
}
// `send_image_request` is given a chat completions URL; the mock is registered
// only on the images endpoint, so the call succeeds only if the URL was
// rewritten. The returned bytes must equal the base64-decoded payload.
#[tokio::test]
async fn image_request_uses_images_endpoint_and_decodes_bytes() {
let mut server = Server::new_async().await;
let expected_bytes = b"fake-png";
let encoded_image = BASE64_STANDARD.encode(expected_bytes);
let _mock = server
.mock("POST", "/v1/images/generations")
.match_body(Matcher::PartialJson(serde_json::json!({
"model": "x/flux2-klein:4b",
"prompt": "draw a square",
"n": 1,
"size": "1024x1024",
"response_format": "b64_json"
})))
.with_status(200)
.with_header("content-type", "application/json")
.with_body(format!(
r#"{{"data":[{{"b64_json":"{}"}}]}}"#,
encoded_image
))
.create_async()
.await;
let url = format!("{}/v1/chat/completions", server.url());
let model = "x/flux2-klein:4b".to_string();
// Empty token: no Authorization header is expected on this request.
let image = send_image_request(&url, &model, "draw a square", 10, "")
.await
.expect("image request should decode");
assert_eq!(image, expected_bytes);
}
}