use crate::types::*;
use anyhow::{anyhow, Context, Result};
use reqwest::blocking::{Client, Response};
use std::time::{Duration, Instant};
use tracing::{debug, warn};
/// Path prefix that `ApiClient::url` inserts between the base URL and each
/// endpoint path; `normalize_base_url` strips it from a configured base URL
/// that already ends with it.
const API_PREFIX: &str = "/graphics/api";
/// `target` attached to the tracing events emitted by the HTTP client code below.
const TRACE_TARGET: &str = "studio_worker::http";
/// Blocking HTTP client for the worker-facing API.
///
/// Construct it with `ApiClient::new`, which normalizes the base URL and
/// builds the underlying reqwest client.
pub struct ApiClient {
/// Base URL as returned by `normalize_base_url`: query and fragment removed,
/// a trailing `/graphics/api` path suffix stripped, and no trailing slash.
pub base_url: String,
/// Blocking reqwest client; `ApiClient::new` builds it with a 60-second timeout.
pub client: Client,
}
impl ApiClient {
    /// Creates a client for the API rooted at `base_url`.
    ///
    /// The HTTP client is built with a 60-second timeout, and `base_url` is
    /// passed through `normalize_base_url` before being stored.
    ///
    /// # Errors
    /// Fails if the HTTP client cannot be built or if `base_url` cannot be
    /// normalized.
    pub fn new(base_url: String) -> Result<Self> {
        // The client is built before the URL is validated, so a client build
        // failure is the error reported when both would fail.
        let http = Client::builder()
            .timeout(Duration::from_secs(60))
            .build()
            .context("building reqwest client")?;
        let base_url = normalize_base_url(&base_url)?;
        Ok(Self {
            base_url,
            client: http,
        })
    }

    /// Joins the stored base URL, the API prefix, and `path` into a full URL.
    ///
    /// `path` is appended verbatim, so callers pass it with a leading `/`.
    fn url(&self, path: &str) -> String {
        [self.base_url.as_str(), API_PREFIX, path].concat()
    }

    /// Whole milliseconds elapsed since `started`, narrowed to `u64` for logging.
    fn millis_since(started: Instant) -> u64 {
        started.elapsed().as_millis() as u64
    }

    /// Logs the outcome of a request and turns non-success statuses into errors.
    ///
    /// On a 2xx status the response is handed back untouched after a debug
    /// event. On any other status the body is read (an unreadable body is
    /// treated as empty), a warning is logged, and an error carrying the
    /// operation name, status, and body is returned.
    fn check(&self, op: &str, url: &str, started: Instant, response: Response) -> Result<Response> {
        let status = response.status();
        // Taken before the error body is read, so it excludes that read.
        let elapsed_ms = Self::millis_since(started);

        // `is_success()` covers the whole 2xx range, 204 included.
        if !status.is_success() {
            let body = response.text().unwrap_or_default();
            warn!(
                target: TRACE_TARGET,
                op,
                endpoint = %url,
                status = status.as_u16(),
                elapsed_ms,
                body = %body,
                "{op} failed"
            );
            return Err(anyhow!("{op} failed: {status} — {body}"));
        }

        debug!(
            target: TRACE_TARGET,
            op,
            endpoint = %url,
            status = status.as_u16(),
            elapsed_ms,
            "ok"
        );
        Ok(response)
    }

    /// POSTs `payload` as JSON to `/workers/register-request` and decodes the
    /// JSON reply.
    ///
    /// # Errors
    /// Fails if the request cannot be sent, the status is not a success, or
    /// the reply body cannot be decoded.
    pub fn register_request(
        &self,
        payload: &AutoRegisterRequest,
    ) -> Result<AutoRegisterRequestResponse> {
        let endpoint = self.url("/workers/register-request");
        let sent_at = Instant::now();
        let raw = self.client.post(&endpoint).json(payload).send()?;
        let decoded: AutoRegisterRequestResponse = self
            .check("register-request", &endpoint, sent_at, raw)?
            .json()?;
        Ok(decoded)
    }

    /// GETs the state of the register request `request_id`, authenticating
    /// with `registration_secret` as a bearer token.
    ///
    /// Returns `Ok(None)` when the server answers 404, and `Ok(Some(_))` with
    /// the decoded JSON reply on success.
    ///
    /// # Errors
    /// Fails if the request cannot be sent, the status is neither 404 nor a
    /// success, or the reply body cannot be decoded.
    pub fn poll_register_status(
        &self,
        request_id: &str,
        registration_secret: &str,
    ) -> Result<Option<RegisterStatus>> {
        let endpoint = self.url(&format!("/workers/register-requests/{request_id}"));
        let sent_at = Instant::now();
        let raw = self
            .client
            .get(&endpoint)
            .bearer_auth(registration_secret)
            .send()?;

        match raw.status().as_u16() {
            // A missing request is reported as `None`, not as an error.
            404 => {
                debug!(
                    target: TRACE_TARGET,
                    op = "register-poll",
                    endpoint = %endpoint,
                    status = 404,
                    elapsed_ms = Self::millis_since(sent_at),
                    "register request not found (stale id; orchestrator will recreate)"
                );
                Ok(None)
            }
            _ => {
                let state: RegisterStatus = self
                    .check("register-poll", &endpoint, sent_at, raw)?
                    .json()?;
                Ok(Some(state))
            }
        }
    }

    /// Uploads the result of job `job_id` for worker `worker_id` as a
    /// multipart form, authenticating with `token` as a bearer token.
    ///
    /// The form carries text fields `prompt` and `ext` plus an `image` part
    /// holding the `image` bytes, named `{job_id}.{ext}` and typed with the
    /// MIME type `mime_for_ext` returns for `ext`.
    ///
    /// # Errors
    /// Fails if the MIME type is rejected, the request cannot be sent, or the
    /// status is not a success.
    pub fn complete(
        &self,
        worker_id: &str,
        token: &str,
        job_id: &str,
        ext: &str,
        prompt: &str,
        image: Vec<u8>,
    ) -> Result<()> {
        use reqwest::blocking::multipart::{Form, Part};

        let mime = mime_for_ext(ext);
        debug!(
            target: TRACE_TARGET,
            op = "complete",
            job_id,
            ext,
            mime,
            bytes = image.len() as u64,
            "uploading job result"
        );

        let upload = Part::bytes(image)
            .file_name(format!("{job_id}.{ext}"))
            .mime_str(mime)?;
        let form = Form::new()
            .text("prompt", prompt.to_owned())
            .text("ext", ext.to_owned())
            .part("image", upload);

        let endpoint = self.url(&format!("/workers/{worker_id}/jobs/{job_id}/complete"));
        let sent_at = Instant::now();
        let raw = self
            .client
            .post(&endpoint)
            .bearer_auth(token)
            .multipart(form)
            .send()?;

        // Only the status matters here; the response body is discarded.
        self.check("complete", &endpoint, sent_at, raw).map(|_| ())
    }
}
/// Parses `base_url` and reduces it to the form stored in `ApiClient::base_url`.
///
/// The query string and fragment are dropped. If the path, ignoring trailing
/// slashes, ends with `API_PREFIX`, that suffix is removed so the prefix is
/// not duplicated when `ApiClient::url` adds it back. The result never ends
/// with `/`.
///
/// # Errors
/// Returns an error naming the offending input when `base_url` does not parse
/// as a URL.
fn normalize_base_url(base_url: &str) -> Result<String> {
    let mut url =
        url::Url::parse(base_url).map_err(|e| anyhow!("invalid api_base_url {base_url:?}: {e}"))?;
    url.set_query(None);
    url.set_fragment(None);

    // Take an owned copy of the outer path so the borrow of `url` ends before
    // `set_path` needs it mutably.
    let outer_path = url
        .path()
        .trim_end_matches('/')
        .strip_suffix(API_PREFIX)
        .map(str::to_owned);
    if let Some(outer) = outer_path {
        // An empty remainder means the prefix was the whole path.
        url.set_path(if outer.is_empty() { "/" } else { &outer });
    }

    Ok(url.as_str().trim_end_matches('/').to_owned())
}
/// Returns the MIME type for the file extension `ext` (given without a dot).
///
/// Matching is exact and case-sensitive. Any extension not in the table,
/// including the empty string, yields `application/octet-stream`.
pub fn mime_for_ext(ext: &str) -> &'static str {
    /// Recognized extensions paired with their MIME types.
    const KNOWN: &[(&str, &str)] = &[
        ("png", "image/png"),
        ("webp", "image/webp"),
        ("gif", "image/gif"),
        ("wav", "audio/wav"),
        ("mp3", "audio/mpeg"),
        ("mp4", "video/mp4"),
    ];

    KNOWN
        .iter()
        .find(|(known, _)| *known == ext)
        .map_or("application/octet-stream", |&(_, mime)| mime)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// MIME type `mime_for_ext` is expected to return for unrecognized input.
    const FALLBACK_MIME: &str = "application/octet-stream";

    /// Builds an `ApiClient` for `base` and returns its register-request URL.
    fn register_request_url(base: &str) -> String {
        ApiClient::new(base.into())
            .unwrap()
            .url("/workers/register-request")
    }

    #[test]
    fn mime_for_ext_maps_known_image_audio_video_types() {
        let expected = [
            ("png", "image/png"),
            ("webp", "image/webp"),
            ("gif", "image/gif"),
            ("wav", "audio/wav"),
            ("mp3", "audio/mpeg"),
            ("mp4", "video/mp4"),
        ];
        for (ext, mime) in expected {
            assert_eq!(mime_for_ext(ext), mime);
        }
    }

    #[test]
    fn mime_for_ext_falls_back_to_octet_stream_for_unknown() {
        for ext in ["bin", ""] {
            assert_eq!(mime_for_ext(ext), FALLBACK_MIME);
        }
    }

    #[test]
    fn normalize_base_url_strips_existing_graphics_api_prefix() {
        // A base URL that already ends in the API prefix must not double it.
        assert_eq!(
            register_request_url("https://studio.example/graphics/api/"),
            "https://studio.example/graphics/api/workers/register-request"
        );
    }

    #[test]
    fn normalize_base_url_preserves_outer_mount_path() {
        // Only the trailing API prefix is stripped; `/custom` stays in place.
        assert_eq!(
            register_request_url("https://studio.example/custom/graphics/api"),
            "https://studio.example/custom/graphics/api/workers/register-request"
        );
    }

    #[test]
    fn mime_for_ext_covers_every_extension_engines_emit() {
        const ENGINE_EXTS: [&str; 4] = ["png", "webp", "gif", "wav"];
        for &ext in &ENGINE_EXTS {
            assert_ne!(
                mime_for_ext(ext),
                "application/octet-stream",
                "engine output extension {ext:?} must map to a real MIME type"
            );
        }
    }
}