use crate::error::MoldError;
use crate::types::{
ExpandRequest, ExpandResponse, GalleryImage, GenerateRequest, GenerateResponse, ImageData,
ModelInfo, ModelInfoExtended, ServerStatus, SseCompleteEvent, SseErrorEvent, SseProgressEvent,
};
use anyhow::Result;
use base64::Engine as _;
use reqwest::Client;
/// HTTP client for a mold server.
///
/// Holds the normalized base URL (see [`normalize_host`]) and a [`reqwest::Client`]. The reqwest
/// client keeps its connection pool behind an `Arc`, so cloning a `MoldClient` is cheap and the
/// clones share that pool.
#[derive(Debug, Clone)]
pub struct MoldClient {
    base_url: String,
    client: Client,
}
impl MoldClient {
pub fn new(base_url: &str) -> Self {
Self {
base_url: normalize_host(base_url),
client: Client::new(),
}
}
pub fn from_env() -> Self {
let base_url =
std::env::var("MOLD_HOST").unwrap_or_else(|_| "http://localhost:7680".to_string());
Self::new(&base_url)
}
pub async fn generate_raw(&self, req: &GenerateRequest) -> Result<Vec<u8>> {
let bytes = self
.client
.post(format!("{}/api/generate", self.base_url))
.json(req)
.send()
.await?
.error_for_status()?
.bytes()
.await?
.to_vec();
Ok(bytes)
}
pub async fn generate(&self, req: GenerateRequest) -> Result<GenerateResponse> {
let fallback_seed = req.seed.unwrap_or(0);
let width = req.width;
let height = req.height;
let model = req.model.clone();
let format = req.output_format;
let start = std::time::Instant::now();
let resp = self
.client
.post(format!("{}/api/generate", self.base_url))
.json(&req)
.send()
.await?
.error_for_status()?;
let seed_used = resp
.headers()
.get("x-mold-seed-used")
.and_then(|v| v.to_str().ok())
.and_then(|s| s.parse::<u64>().ok())
.unwrap_or(fallback_seed);
let data = resp.bytes().await?.to_vec();
let generation_time_ms = start.elapsed().as_millis() as u64;
Ok(GenerateResponse {
images: vec![ImageData {
data,
format,
width,
height,
index: 0,
}],
generation_time_ms,
model,
seed_used,
})
}
pub async fn list_models(&self) -> Result<Vec<ModelInfo>> {
let models = self.list_models_extended().await?;
Ok(models.into_iter().map(|m| m.info).collect())
}
pub async fn list_models_extended(&self) -> Result<Vec<ModelInfoExtended>> {
let resp = self
.client
.get(format!("{}/api/models", self.base_url))
.send()
.await?
.error_for_status()?
.json::<Vec<ModelInfoExtended>>()
.await?;
Ok(resp)
}
pub fn is_connection_error(err: &anyhow::Error) -> bool {
if let Some(mold_err) = err.downcast_ref::<MoldError>() {
if matches!(mold_err, MoldError::Client(_)) {
return true;
}
}
if let Some(reqwest_err) = err.downcast_ref::<reqwest::Error>() {
return reqwest_err.is_connect();
}
false
}
pub fn is_model_not_found(err: &anyhow::Error) -> bool {
if let Some(mold_err) = err.downcast_ref::<MoldError>() {
if matches!(mold_err, MoldError::ModelNotFound(_)) {
return true;
}
}
if let Some(reqwest_err) = err.downcast_ref::<reqwest::Error>() {
return reqwest_err.status() == Some(reqwest::StatusCode::NOT_FOUND);
}
err.downcast_ref::<ModelNotFoundError>().is_some()
}
pub async fn generate_stream(
&self,
req: &GenerateRequest,
progress_tx: tokio::sync::mpsc::UnboundedSender<SseProgressEvent>,
) -> Result<Option<GenerateResponse>> {
let mut resp = self
.client
.post(format!("{}/api/generate/stream", self.base_url))
.json(req)
.send()
.await?;
if resp.status() == reqwest::StatusCode::NOT_FOUND {
let body = resp.text().await.unwrap_or_default();
if body.is_empty() {
return Ok(None);
}
return Err(MoldError::ModelNotFound(body).into());
}
if resp.status() == reqwest::StatusCode::UNPROCESSABLE_ENTITY {
let body = resp.text().await.unwrap_or_default();
return Err(MoldError::Validation(format!("validation error: {body}")).into());
}
if resp.status().is_client_error() || resp.status().is_server_error() {
let status = resp.status();
let body = resp.text().await.unwrap_or_default();
anyhow::bail!("server error {status}: {body}");
}
let mut buffer = String::new();
while let Some(chunk) = resp.chunk().await? {
buffer.push_str(&String::from_utf8_lossy(&chunk));
while let Some(pos) = buffer.find("\n\n") {
let event_text = buffer[..pos].to_string();
buffer = buffer[pos + 2..].to_string();
let mut event_type = String::new();
let mut data = String::new();
for line in event_text.lines() {
if line.starts_with(':') {
continue; }
if let Some(t) = line.strip_prefix("event:") {
event_type = t.trim().to_string();
} else if let Some(d) = line.strip_prefix("data:") {
data = d.trim().to_string();
}
}
match event_type.as_str() {
"progress" => {
if let Ok(p) = serde_json::from_str::<SseProgressEvent>(&data) {
let _ = progress_tx.send(p);
}
}
"complete" => {
let complete: SseCompleteEvent = serde_json::from_str(&data)?;
let image_data =
base64::engine::general_purpose::STANDARD.decode(&complete.image)?;
return Ok(Some(GenerateResponse {
images: vec![ImageData {
data: image_data,
format: complete.format,
width: complete.width,
height: complete.height,
index: 0,
}],
generation_time_ms: complete.generation_time_ms,
model: req.model.clone(),
seed_used: complete.seed_used,
}));
}
"error" => {
let error: SseErrorEvent = serde_json::from_str(&data)?;
anyhow::bail!("server error: {}", error.message);
}
_ => {}
}
}
}
anyhow::bail!("SSE stream ended without complete event")
}
pub async fn pull_model(&self, model: &str) -> Result<String> {
let resp = self
.client
.post(format!("{}/api/models/pull", self.base_url))
.json(&serde_json::json!({ "model": model }))
.send()
.await?
.error_for_status()?
.text()
.await?;
Ok(resp)
}
pub async fn pull_model_stream(
&self,
model: &str,
progress_tx: tokio::sync::mpsc::UnboundedSender<SseProgressEvent>,
) -> Result<()> {
let mut resp = self
.client
.post(format!("{}/api/models/pull", self.base_url))
.header("Accept", "text/event-stream")
.json(&serde_json::json!({ "model": model }))
.send()
.await?;
if resp.status().is_client_error() || resp.status().is_server_error() {
let status = resp.status();
let body = resp.text().await.unwrap_or_default();
anyhow::bail!("server error {status}: {body}");
}
let content_type = resp
.headers()
.get("content-type")
.and_then(|v| v.to_str().ok())
.unwrap_or("");
if !content_type.contains("text/event-stream") {
drop(progress_tx);
let _ = resp.text().await?;
return Ok(());
}
let mut buffer = String::new();
while let Some(chunk) = resp.chunk().await? {
buffer.push_str(&String::from_utf8_lossy(&chunk));
while let Some(pos) = buffer.find("\n\n") {
let event_text = buffer[..pos].to_string();
buffer = buffer[pos + 2..].to_string();
let mut event_type = String::new();
let mut data = String::new();
for line in event_text.lines() {
if line.starts_with(':') {
continue;
}
if let Some(t) = line.strip_prefix("event:") {
event_type = t.trim().to_string();
} else if let Some(d) = line.strip_prefix("data:") {
data = d.trim().to_string();
}
}
match event_type.as_str() {
"progress" => {
if let Ok(p) = serde_json::from_str::<SseProgressEvent>(&data) {
let is_done = matches!(p, SseProgressEvent::PullComplete { .. });
let _ = progress_tx.send(p);
if is_done {
return Ok(());
}
}
}
"error" => {
let error: SseErrorEvent = serde_json::from_str(&data)?;
anyhow::bail!("server error: {}", error.message);
}
_ => {}
}
}
}
Ok(())
}
pub fn host(&self) -> &str {
&self.base_url
}
pub async fn unload_model(&self) -> Result<String> {
let resp = self
.client
.delete(format!("{}/api/models/unload", self.base_url))
.send()
.await?
.error_for_status()?
.text()
.await?;
Ok(resp)
}
pub async fn server_status(&self) -> Result<ServerStatus> {
let resp = self
.client
.get(format!("{}/api/status", self.base_url))
.send()
.await?
.error_for_status()?
.json::<ServerStatus>()
.await?;
Ok(resp)
}
pub async fn list_gallery(&self) -> Result<Vec<GalleryImage>> {
let resp = self
.client
.get(format!("{}/api/gallery", self.base_url))
.send()
.await?
.error_for_status()?
.json::<Vec<GalleryImage>>()
.await?;
Ok(resp)
}
pub async fn get_gallery_image(&self, filename: &str) -> Result<Vec<u8>> {
let resp = self
.client
.get(format!("{}/api/gallery/image/{filename}", self.base_url))
.send()
.await?
.error_for_status()?
.bytes()
.await?;
Ok(resp.to_vec())
}
pub async fn delete_gallery_image(&self, filename: &str) -> Result<()> {
self.client
.delete(format!("{}/api/gallery/image/{filename}", self.base_url))
.send()
.await?
.error_for_status()?;
Ok(())
}
pub async fn get_gallery_thumbnail(&self, filename: &str) -> Result<Vec<u8>> {
let resp = self
.client
.get(format!(
"{}/api/gallery/thumbnail/{filename}",
self.base_url
))
.send()
.await?
.error_for_status()?
.bytes()
.await?;
Ok(resp.to_vec())
}
pub async fn expand_prompt(&self, req: &ExpandRequest) -> Result<ExpandResponse> {
let resp = self
.client
.post(format!("{}/api/expand", self.base_url))
.json(req)
.send()
.await?
.error_for_status()?
.json::<ExpandResponse>()
.await?;
Ok(resp)
}
}
/// Turns a user-supplied server address into a base URL.
///
/// Surrounding whitespace and trailing slashes are removed. Inputs that already contain a scheme
/// (`://`) are otherwise returned unchanged. For everything else, `http://` is prepended, and the
/// default port 7680 is added when the host part (the text before the first `/`) has no port:
///
/// - `hal9000` → `http://hal9000:7680`
/// - `hal9000:8080` → `http://hal9000:8080`
/// - `hal9000/mold` → `http://hal9000:7680/mold`
/// - `::1` / `[::1]` → `http://[::1]:7680`
/// - `[::1]:9090` → `http://[::1]:9090`
pub fn normalize_host(input: &str) -> String {
    const DEFAULT_PORT: u16 = 7680;

    let trimmed = input.trim().trim_end_matches('/');
    if trimmed.contains("://") {
        return trimmed.to_string();
    }

    // Only the authority (before any path) may carry a port, so the port must go there.
    let (authority, path) = trimmed.split_at(trimmed.find('/').unwrap_or(trimmed.len()));

    let authority = if let Some(inside) = authority.strip_prefix('[') {
        // Bracketed IPv6 literal: a port can only follow the closing bracket.
        match inside.split_once(']') {
            Some((_, after)) if after.starts_with(':') => authority.to_string(),
            _ => format!("{authority}:{DEFAULT_PORT}"),
        }
    } else {
        match authority.matches(':').count() {
            0 => format!("{authority}:{DEFAULT_PORT}"),
            1 => authority.to_string(),
            // Several colons without brackets can only be a bare IPv6 literal.
            _ => format!("[{authority}]:{DEFAULT_PORT}"),
        }
    };
    format!("http://{authority}{path}")
}
/// Error value that `MoldClient::is_model_not_found` recognises as "model not found".
///
/// The wrapped string is the message, and it is also the `Display` output, verbatim.
#[derive(Debug)]
pub struct ModelNotFoundError(pub String);

impl std::error::Error for ModelNotFoundError {}

impl std::fmt::Display for ModelNotFoundError {
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        formatter.write_str(&self.0)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_support::ENV_LOCK;

    /// Snapshot of `MOLD_HOST` that is written back on drop.
    ///
    /// This way an env-mutating test leaves the process environment as it found it, even when an
    /// assertion panics.
    struct MoldHostGuard(Option<std::ffi::OsString>);

    impl MoldHostGuard {
        fn capture() -> Self {
            Self(std::env::var_os("MOLD_HOST"))
        }
    }

    impl Drop for MoldHostGuard {
        fn drop(&mut self) {
            // SAFETY: mutating the environment is only unsound if another thread touches it at the
            // same time. Create this guard *after* taking `ENV_LOCK` so that it drops first, while
            // the lock is still held. This assumes every env-touching test in the crate serializes
            // on `ENV_LOCK`.
            match self.0.take() {
                Some(previous) => unsafe { std::env::set_var("MOLD_HOST", previous) },
                None => unsafe { std::env::remove_var("MOLD_HOST") },
            }
        }
    }

    #[test]
    fn test_new_trims_trailing_slash() {
        let client = MoldClient::new("http://localhost:7680/");
        assert_eq!(client.host(), "http://localhost:7680");
    }

    #[test]
    fn test_new_no_slash_unchanged() {
        let client = MoldClient::new("http://localhost:7680");
        assert_eq!(client.host(), "http://localhost:7680");
    }

    #[test]
    fn test_new_multiple_slashes() {
        let client = MoldClient::new("http://localhost:7680///");
        assert_eq!(client.host(), "http://localhost:7680");
    }

    #[test]
    fn test_from_env_mold_host() {
        let _lock = ENV_LOCK.lock().unwrap_or_else(|e| e.into_inner());
        // Declared after `_lock`, so MOLD_HOST is restored before the lock is released.
        let _restore = MoldHostGuard::capture();

        // SAFETY: `ENV_LOCK` is held for the whole test (see `MoldHostGuard::drop`).
        unsafe { std::env::remove_var("MOLD_HOST") };
        let client = MoldClient::from_env();
        assert_eq!(client.host(), "http://localhost:7680");

        let unique_url = "http://test-host-env:9999";
        // SAFETY: `ENV_LOCK` is held for the whole test (see `MoldHostGuard::drop`).
        unsafe { std::env::set_var("MOLD_HOST", unique_url) };
        let client = MoldClient::from_env();
        assert_eq!(client.host(), unique_url);
    }

    #[test]
    fn test_is_connection_error_non_connect() {
        let err = anyhow::anyhow!("something went wrong");
        assert!(!MoldClient::is_connection_error(&err));
    }

    #[test]
    fn test_is_model_not_found_via_custom_error() {
        let err: anyhow::Error =
            ModelNotFoundError("model 'test' is not downloaded".to_string()).into();
        assert!(MoldClient::is_model_not_found(&err));
    }

    #[test]
    fn test_is_model_not_found_generic_error() {
        let err = anyhow::anyhow!("something else");
        assert!(!MoldClient::is_model_not_found(&err));
    }

    #[test]
    fn test_normalize_bare_hostname() {
        let client = MoldClient::new("hal9000");
        assert_eq!(client.host(), "http://hal9000:7680");
    }

    #[test]
    fn test_normalize_hostname_with_port() {
        let client = MoldClient::new("hal9000:8080");
        assert_eq!(client.host(), "http://hal9000:8080");
    }

    #[test]
    fn test_normalize_full_url_unchanged() {
        let client = MoldClient::new("http://hal9000:7680");
        assert_eq!(client.host(), "http://hal9000:7680");
    }

    #[test]
    fn test_normalize_https_no_port() {
        let client = MoldClient::new("https://hal9000");
        assert_eq!(client.host(), "https://hal9000");
    }

    #[test]
    fn test_normalize_http_no_port() {
        let client = MoldClient::new("http://hal9000");
        assert_eq!(client.host(), "http://hal9000");
    }

    #[test]
    fn test_normalize_localhost() {
        let client = MoldClient::new("localhost");
        assert_eq!(client.host(), "http://localhost:7680");
    }

    #[test]
    fn test_normalize_whitespace_trimmed() {
        let client = MoldClient::new(" hal9000 ");
        assert_eq!(client.host(), "http://hal9000:7680");
    }

    #[test]
    fn test_normalize_ip_address() {
        let client = MoldClient::new("192.168.1.100");
        assert_eq!(client.host(), "http://192.168.1.100:7680");
    }

    #[test]
    fn test_normalize_ip_with_port() {
        let client = MoldClient::new("192.168.1.100:9090");
        assert_eq!(client.host(), "http://192.168.1.100:9090");
    }

    #[test]
    fn test_is_model_not_found_via_mold_error() {
        let err: anyhow::Error =
            MoldError::ModelNotFound("model 'test' is not downloaded".to_string()).into();
        assert!(MoldClient::is_model_not_found(&err));
    }

    #[test]
    fn test_is_connection_error_via_mold_error() {
        let err: anyhow::Error = MoldError::Client("connection refused".to_string()).into();
        assert!(MoldClient::is_connection_error(&err));
    }
}