Skip to main content

mold_core/
client.rs

1use crate::error::MoldError;
2use crate::types::{
3    ExpandRequest, ExpandResponse, GalleryImage, GenerateRequest, GenerateResponse, ImageData,
4    ModelInfo, ModelInfoExtended, ServerStatus, SseCompleteEvent, SseErrorEvent, SseProgressEvent,
5    VideoData,
6};
7use anyhow::Result;
8use base64::Engine as _;
9use reqwest::Client;
10
11pub struct MoldClient {
12    base_url: String,
13    client: Client,
14}
15
16impl MoldClient {
17    pub fn new(base_url: &str) -> Self {
18        let client = build_client(None);
19        Self {
20            base_url: normalize_host(base_url),
21            client,
22        }
23    }
24
25    /// Create a client with an explicit API key for authentication.
26    pub fn with_api_key(base_url: &str, api_key: String) -> Self {
27        let client = build_client(Some(&api_key));
28        Self {
29            base_url: normalize_host(base_url),
30            client,
31        }
32    }
33
34    pub fn from_env() -> Self {
35        let base_url =
36            std::env::var("MOLD_HOST").unwrap_or_else(|_| "http://localhost:7680".to_string());
37        let api_key = std::env::var("MOLD_API_KEY").ok().filter(|k| !k.is_empty());
38        let client = build_client(api_key.as_deref());
39        Self {
40            base_url: normalize_host(&base_url),
41            client,
42        }
43    }
44
45    /// Generate an image. Returns raw image bytes (PNG or JPEG).
46    /// The server returns raw bytes, not JSON — callers are responsible for
47    /// writing the bytes to disk or further processing.
48    pub async fn generate_raw(&self, req: &GenerateRequest) -> Result<Vec<u8>> {
49        let bytes = self
50            .client
51            .post(format!("{}/api/generate", self.base_url))
52            .json(req)
53            .send()
54            .await?
55            .error_for_status()?
56            .bytes()
57            .await?
58            .to_vec();
59        Ok(bytes)
60    }
61
62    /// Generate an image or video and return the response wrapping the raw bytes.
63    ///
64    /// For video responses the server sends `x-mold-video-*` metadata headers
65    /// alongside the raw video bytes so we can reconstruct [`VideoData`].
66    pub async fn generate(&self, req: GenerateRequest) -> Result<GenerateResponse> {
67        let fallback_seed = req.seed.unwrap_or(0);
68        let width = req.width;
69        let height = req.height;
70        let model = req.model.clone();
71        let format = req.output_format;
72
73        let start = std::time::Instant::now();
74        let resp = self
75            .client
76            .post(format!("{}/api/generate", self.base_url))
77            .json(&req)
78            .send()
79            .await?
80            .error_for_status()?;
81
82        // Read the seed the server actually used from the response header.
83        // Fall back to the request seed for backward compat with older servers.
84        let seed_used = resp
85            .headers()
86            .get("x-mold-seed-used")
87            .and_then(|v| v.to_str().ok())
88            .and_then(|s| s.parse::<u64>().ok())
89            .unwrap_or(fallback_seed);
90
91        // Detect video response via x-mold-video-frames header
92        let video_meta = parse_video_headers(resp.headers());
93
94        let data = resp.bytes().await?.to_vec();
95        let generation_time_ms = start.elapsed().as_millis() as u64;
96
97        let video = video_meta.map(|meta| VideoData {
98            data: data.clone(),
99            format,
100            width: meta.width.unwrap_or(width),
101            height: meta.height.unwrap_or(height),
102            frames: meta.frames,
103            fps: meta.fps,
104            thumbnail: Vec::new(),
105            gif_preview: Vec::new(),
106            has_audio: meta.has_audio,
107            duration_ms: meta.duration_ms,
108            audio_sample_rate: meta.audio_sample_rate,
109            audio_channels: meta.audio_channels,
110        });
111
112        // For video responses, images is empty — the payload lives in `video`.
113        let images = if video.is_some() {
114            Vec::new()
115        } else {
116            vec![ImageData {
117                data,
118                format,
119                width,
120                height,
121                index: 0,
122            }]
123        };
124
125        Ok(GenerateResponse {
126            images,
127            generation_time_ms,
128            model,
129            seed_used,
130            video,
131        })
132    }
133
134    pub async fn list_models(&self) -> Result<Vec<ModelInfo>> {
135        let models = self.list_models_extended().await?;
136        Ok(models.into_iter().map(|m| m.info).collect())
137    }
138
139    pub async fn list_models_extended(&self) -> Result<Vec<ModelInfoExtended>> {
140        let resp = self
141            .client
142            .get(format!("{}/api/models", self.base_url))
143            .send()
144            .await?
145            .error_for_status()?
146            .json::<Vec<ModelInfoExtended>>()
147            .await?;
148        Ok(resp)
149    }
150
151    /// Check whether an error is a connection error (e.g. "connection refused").
152    /// Useful for deciding whether to fall back to local inference.
153    pub fn is_connection_error(err: &anyhow::Error) -> bool {
154        // Check for MoldError::Client variant
155        if let Some(mold_err) = err.downcast_ref::<MoldError>() {
156            if matches!(mold_err, MoldError::Client(_)) {
157                return true;
158            }
159        }
160        if let Some(reqwest_err) = err.downcast_ref::<reqwest::Error>() {
161            return reqwest_err.is_connect();
162        }
163        false
164    }
165
166    /// Check whether an error is a 404 "model not found" from the server.
167    /// Useful for triggering a server-side pull when the model isn't downloaded.
168    pub fn is_model_not_found(err: &anyhow::Error) -> bool {
169        // Check for MoldError::ModelNotFound variant
170        if let Some(mold_err) = err.downcast_ref::<MoldError>() {
171            if matches!(mold_err, MoldError::ModelNotFound(_)) {
172                return true;
173            }
174        }
175        if let Some(reqwest_err) = err.downcast_ref::<reqwest::Error>() {
176            return reqwest_err.status() == Some(reqwest::StatusCode::NOT_FOUND);
177        }
178        // SSE streaming returns ModelNotFoundError instead of reqwest status errors
179        err.downcast_ref::<ModelNotFoundError>().is_some()
180    }
181
182    /// Generate an image via SSE streaming, receiving progress events.
183    ///
184    /// Returns:
185    /// - `Ok(Some(response))` — streaming succeeded
186    /// - `Ok(None)` — server doesn't support SSE (endpoint returned 404 with empty body)
187    /// - `Err(e)` — generation error, model not found, or connection error
188    pub async fn generate_stream(
189        &self,
190        req: &GenerateRequest,
191        progress_tx: tokio::sync::mpsc::UnboundedSender<SseProgressEvent>,
192    ) -> Result<Option<GenerateResponse>> {
193        let mut resp = self
194            .client
195            .post(format!("{}/api/generate/stream", self.base_url))
196            .json(req)
197            .send()
198            .await?;
199
200        if resp.status() == reqwest::StatusCode::NOT_FOUND {
201            let body = resp.text().await.unwrap_or_default();
202            if body.is_empty() {
203                // Axum returns empty 404 for unmatched routes — server doesn't support SSE
204                return Ok(None);
205            }
206            // Non-empty 404 = model not found
207            return Err(MoldError::ModelNotFound(body).into());
208        }
209
210        if resp.status() == reqwest::StatusCode::UNPROCESSABLE_ENTITY {
211            let body = resp.text().await.unwrap_or_default();
212            return Err(MoldError::Validation(format!("validation error: {body}")).into());
213        }
214
215        if resp.status().is_client_error() || resp.status().is_server_error() {
216            let status = resp.status();
217            let body = resp.text().await.unwrap_or_default();
218            anyhow::bail!("server error {status}: {body}");
219        }
220
221        // Parse SSE events from chunked response body
222        let mut buffer = String::new();
223        while let Some(chunk) = resp.chunk().await? {
224            buffer.push_str(&String::from_utf8_lossy(&chunk));
225
226            while let Some(event_text) = next_sse_event(&mut buffer) {
227                let (event_type, data) = parse_sse_event(&event_text);
228                match event_type.as_str() {
229                    "progress" => {
230                        if let Ok(p) = serde_json::from_str::<SseProgressEvent>(&data) {
231                            let _ = progress_tx.send(p);
232                        }
233                    }
234                    "complete" => {
235                        let complete: SseCompleteEvent = serde_json::from_str(&data)?;
236                        let payload =
237                            base64::engine::general_purpose::STANDARD.decode(&complete.image)?;
238                        let b64 = base64::engine::general_purpose::STANDARD;
239                        // Use server-provided model name (source of truth);
240                        // fall back to request model for backwards compat with
241                        // older servers that don't include it.
242                        let model = if complete.model.is_empty() {
243                            req.model.clone()
244                        } else {
245                            complete.model
246                        };
247
248                        // Detect video response via video_frames field
249                        let (images, video) = if let (Some(frames), Some(fps)) =
250                            (complete.video_frames, complete.video_fps)
251                        {
252                            let thumbnail = complete
253                                .video_thumbnail
254                                .as_deref()
255                                .and_then(|s| b64.decode(s).ok())
256                                .unwrap_or_default();
257                            let gif_preview = complete
258                                .video_gif_preview
259                                .as_deref()
260                                .and_then(|s| b64.decode(s).ok())
261                                .unwrap_or_default();
262                            let vd = VideoData {
263                                data: payload,
264                                format: complete.format,
265                                width: complete.width,
266                                height: complete.height,
267                                frames,
268                                fps,
269                                thumbnail,
270                                gif_preview,
271                                has_audio: complete.video_has_audio,
272                                duration_ms: complete.video_duration_ms,
273                                audio_sample_rate: complete.video_audio_sample_rate,
274                                audio_channels: complete.video_audio_channels,
275                            };
276                            (Vec::new(), Some(vd))
277                        } else {
278                            let img = ImageData {
279                                data: payload,
280                                format: complete.format,
281                                width: complete.width,
282                                height: complete.height,
283                                index: 0,
284                            };
285                            (vec![img], None)
286                        };
287
288                        return Ok(Some(GenerateResponse {
289                            images,
290                            generation_time_ms: complete.generation_time_ms,
291                            model,
292                            seed_used: complete.seed_used,
293                            video,
294                        }));
295                    }
296                    "error" => {
297                        let error: SseErrorEvent = serde_json::from_str(&data)?;
298                        anyhow::bail!("server error: {}", error.message);
299                    }
300                    _ => {}
301                }
302            }
303        }
304
305        anyhow::bail!("SSE stream ended without complete event")
306    }
307
308    /// Ask the server to pull (download) a model. Blocks until the download
309    /// completes on the server side. The server updates its in-memory config
310    /// so subsequent generate/load requests can find the model.
311    pub async fn pull_model(&self, model: &str) -> Result<String> {
312        let resp = self
313            .client
314            .post(format!("{}/api/models/pull", self.base_url))
315            .json(&serde_json::json!({ "model": model }))
316            .send()
317            .await?
318            .error_for_status()?
319            .text()
320            .await?;
321        Ok(resp)
322    }
323
324    /// Request graceful server shutdown.
325    pub async fn shutdown_server(&self) -> Result<()> {
326        self.client
327            .post(format!("{}/api/shutdown", self.base_url))
328            .send()
329            .await?
330            .error_for_status()?;
331        Ok(())
332    }
333
334    /// Pull a model via SSE streaming, receiving download progress events.
335    ///
336    /// Sends `Accept: text/event-stream` to request SSE from the server.
337    /// Falls back to blocking pull if the server doesn't support SSE.
338    pub async fn pull_model_stream(
339        &self,
340        model: &str,
341        progress_tx: tokio::sync::mpsc::UnboundedSender<SseProgressEvent>,
342    ) -> Result<()> {
343        let mut resp = self
344            .client
345            .post(format!("{}/api/models/pull", self.base_url))
346            .header("Accept", "text/event-stream")
347            .json(&serde_json::json!({ "model": model }))
348            .send()
349            .await?;
350
351        if resp.status().is_client_error() || resp.status().is_server_error() {
352            let status = resp.status();
353            let body = resp.text().await.unwrap_or_default();
354            anyhow::bail!("server error {status}: {body}");
355        }
356
357        // Check if server returned SSE or plain text
358        let content_type = resp
359            .headers()
360            .get("content-type")
361            .and_then(|v| v.to_str().ok())
362            .unwrap_or("");
363
364        if !content_type.contains("text/event-stream") {
365            // Old server — blocking pull, no progress. Just consume the response.
366            // Drop the sender so the receiver's recv() returns None instead of blocking.
367            drop(progress_tx);
368            let _ = resp.text().await?;
369            return Ok(());
370        }
371
372        // Parse SSE events (same pattern as generate_stream)
373        let mut buffer = String::new();
374        while let Some(chunk) = resp.chunk().await? {
375            buffer.push_str(&String::from_utf8_lossy(&chunk));
376
377            while let Some(event_text) = next_sse_event(&mut buffer) {
378                let (event_type, data) = parse_sse_event(&event_text);
379                match event_type.as_str() {
380                    "progress" => {
381                        if let Ok(p) = serde_json::from_str::<SseProgressEvent>(&data) {
382                            // PullComplete signals end of pull
383                            let is_done = matches!(p, SseProgressEvent::PullComplete { .. });
384                            let _ = progress_tx.send(p);
385                            if is_done {
386                                return Ok(());
387                            }
388                        }
389                    }
390                    "error" => {
391                        let error: SseErrorEvent = serde_json::from_str(&data)?;
392                        anyhow::bail!("server error: {}", error.message);
393                    }
394                    _ => {}
395                }
396            }
397        }
398
399        Ok(())
400    }
401
402    pub fn host(&self) -> &str {
403        &self.base_url
404    }
405
406    pub async fn unload_model(&self) -> Result<String> {
407        let resp = self
408            .client
409            .delete(format!("{}/api/models/unload", self.base_url))
410            .send()
411            .await?
412            .error_for_status()?
413            .text()
414            .await?;
415        Ok(resp)
416    }
417
418    pub async fn server_status(&self) -> Result<ServerStatus> {
419        let resp = self
420            .client
421            .get(format!("{}/api/status", self.base_url))
422            .send()
423            .await?
424            .error_for_status()?
425            .json::<ServerStatus>()
426            .await?;
427        Ok(resp)
428    }
429
430    /// List gallery images from the server's output directory.
431    pub async fn list_gallery(&self) -> Result<Vec<GalleryImage>> {
432        let resp = self
433            .client
434            .get(format!("{}/api/gallery", self.base_url))
435            .send()
436            .await?
437            .error_for_status()?
438            .json::<Vec<GalleryImage>>()
439            .await?;
440        Ok(resp)
441    }
442
443    /// Download a gallery image by filename.
444    pub async fn get_gallery_image(&self, filename: &str) -> Result<Vec<u8>> {
445        let resp = self
446            .client
447            .get(format!("{}/api/gallery/image/{filename}", self.base_url))
448            .send()
449            .await?
450            .error_for_status()?
451            .bytes()
452            .await?;
453        Ok(resp.to_vec())
454    }
455
456    /// Delete a gallery image on the server.
457    pub async fn delete_gallery_image(&self, filename: &str) -> Result<()> {
458        self.client
459            .delete(format!("{}/api/gallery/image/{filename}", self.base_url))
460            .send()
461            .await?
462            .error_for_status()?;
463        Ok(())
464    }
465
466    /// Download a gallery thumbnail by filename. Smaller/faster than full image.
467    pub async fn get_gallery_thumbnail(&self, filename: &str) -> Result<Vec<u8>> {
468        let resp = self
469            .client
470            .get(format!(
471                "{}/api/gallery/thumbnail/{filename}",
472                self.base_url
473            ))
474            .send()
475            .await?
476            .error_for_status()?
477            .bytes()
478            .await?;
479        Ok(resp.to_vec())
480    }
481
482    /// Expand a prompt using the server's LLM prompt expansion endpoint.
483    pub async fn expand_prompt(&self, req: &ExpandRequest) -> Result<ExpandResponse> {
484        let resp = self
485            .client
486            .post(format!("{}/api/expand", self.base_url))
487            .json(req)
488            .send()
489            .await?
490            .error_for_status()?
491            .json::<ExpandResponse>()
492            .await?;
493        Ok(resp)
494    }
495
496    /// Upscale an image using a super-resolution model on the server.
497    pub async fn upscale(&self, req: &crate::UpscaleRequest) -> Result<crate::UpscaleResponse> {
498        let resp = self
499            .client
500            .post(format!("{}/api/upscale", self.base_url))
501            .json(req)
502            .send()
503            .await?
504            .error_for_status()?
505            .json::<crate::UpscaleResponse>()
506            .await?;
507        Ok(resp)
508    }
509
510    /// Upscale an image via SSE streaming -- progress events are sent to `progress_tx`,
511    /// returns the final `UpscaleResponse` on success.
512    pub async fn upscale_stream(
513        &self,
514        req: &crate::UpscaleRequest,
515        progress_tx: tokio::sync::mpsc::UnboundedSender<SseProgressEvent>,
516    ) -> Result<Option<crate::UpscaleResponse>> {
517        let mut resp = self
518            .client
519            .post(format!("{}/api/upscale/stream", self.base_url))
520            .json(req)
521            .send()
522            .await?;
523
524        if resp.status() == reqwest::StatusCode::NOT_FOUND {
525            let body = resp.text().await.unwrap_or_default();
526            if body.is_empty() {
527                return Ok(None); // server doesn't support SSE upscale
528            }
529            return Err(MoldError::ModelNotFound(body).into());
530        }
531
532        if resp.status() == reqwest::StatusCode::UNPROCESSABLE_ENTITY {
533            let body = resp.text().await.unwrap_or_default();
534            return Err(MoldError::Validation(format!("validation error: {body}")).into());
535        }
536
537        if resp.status().is_client_error() || resp.status().is_server_error() {
538            let status = resp.status();
539            let body = resp.text().await.unwrap_or_default();
540            anyhow::bail!("server error {status}: {body}");
541        }
542
543        let mut buffer = String::new();
544        while let Some(chunk) = resp.chunk().await? {
545            buffer.push_str(&String::from_utf8_lossy(&chunk));
546
547            while let Some(event_text) = next_sse_event(&mut buffer) {
548                let (event_type, data) = parse_sse_event(&event_text);
549                match event_type.as_str() {
550                    "progress" => {
551                        if let Ok(p) = serde_json::from_str::<SseProgressEvent>(&data) {
552                            let _ = progress_tx.send(p);
553                        }
554                    }
555                    "complete" => {
556                        let complete: crate::SseUpscaleCompleteEvent = serde_json::from_str(&data)?;
557                        let image_data =
558                            base64::engine::general_purpose::STANDARD.decode(&complete.image)?;
559                        return Ok(Some(crate::UpscaleResponse {
560                            image: crate::ImageData {
561                                data: image_data,
562                                format: complete.format,
563                                width: complete.original_width * complete.scale_factor,
564                                height: complete.original_height * complete.scale_factor,
565                                index: 0,
566                            },
567                            upscale_time_ms: complete.upscale_time_ms,
568                            model: complete.model,
569                            scale_factor: complete.scale_factor,
570                            original_width: complete.original_width,
571                            original_height: complete.original_height,
572                        }));
573                    }
574                    "error" => {
575                        let error: crate::SseErrorEvent = serde_json::from_str(&data)?;
576                        anyhow::bail!("server error: {}", error.message);
577                    }
578                    _ => {}
579                }
580            }
581        }
582
583        anyhow::bail!("SSE stream ended without complete event")
584    }
585}
586
/// Parsed video metadata from `x-mold-video-*` response headers.
///
/// Optional fields are `None` when the corresponding header is absent or
/// unparsable.
#[derive(Debug, Clone, PartialEq, Eq)]
struct VideoMeta {
    /// From `x-mold-video-frames`; its presence marks a video response.
    frames: u32,
    /// From `x-mold-video-fps`.
    fps: u32,
    /// From `x-mold-video-width`.
    width: Option<u32>,
    /// From `x-mold-video-height`.
    height: Option<u32>,
    /// From `x-mold-video-has-audio`.
    has_audio: bool,
    /// From `x-mold-video-duration-ms`.
    duration_ms: Option<u64>,
    /// From `x-mold-video-audio-sample-rate`.
    audio_sample_rate: Option<u32>,
    /// From `x-mold-video-audio-channels`.
    audio_channels: Option<u32>,
}
598
599/// Parse video metadata from HTTP response headers.
600/// Returns `Some` when `x-mold-video-frames` is present, indicating a video response.
601fn parse_video_headers(headers: &reqwest::header::HeaderMap) -> Option<VideoMeta> {
602    let frames = headers
603        .get("x-mold-video-frames")
604        .and_then(|v| v.to_str().ok())
605        .and_then(|s| s.parse::<u32>().ok())?;
606    let fps = headers
607        .get("x-mold-video-fps")
608        .and_then(|v| v.to_str().ok())
609        .and_then(|s| s.parse::<u32>().ok())
610        .unwrap_or(24);
611    let width = headers
612        .get("x-mold-video-width")
613        .and_then(|v| v.to_str().ok())
614        .and_then(|s| s.parse::<u32>().ok());
615    let height = headers
616        .get("x-mold-video-height")
617        .and_then(|v| v.to_str().ok())
618        .and_then(|s| s.parse::<u32>().ok());
619    let has_audio = headers
620        .get("x-mold-video-has-audio")
621        .and_then(|v| v.to_str().ok())
622        .map(|s| s == "1")
623        .unwrap_or(false);
624    let duration_ms = headers
625        .get("x-mold-video-duration-ms")
626        .and_then(|v| v.to_str().ok())
627        .and_then(|s| s.parse::<u64>().ok());
628    let audio_sample_rate = headers
629        .get("x-mold-video-audio-sample-rate")
630        .and_then(|v| v.to_str().ok())
631        .and_then(|s| s.parse::<u32>().ok());
632    let audio_channels = headers
633        .get("x-mold-video-audio-channels")
634        .and_then(|v| v.to_str().ok())
635        .and_then(|s| s.parse::<u32>().ok());
636
637    Some(VideoMeta {
638        frames,
639        fps,
640        width,
641        height,
642        has_audio,
643        duration_ms,
644        audio_sample_rate,
645        audio_channels,
646    })
647}
648
/// Split the next complete SSE event off the front of `buffer`.
///
/// An event ends at a blank line, written as either `\n\n` or `\r\n\r\n`.
/// The earliest terminator wins, so a stream that mixes line endings is
/// still split at the right place. On success the event text (without its
/// terminator) is returned and removed from `buffer` along with the
/// terminator. Returns `None` and leaves `buffer` untouched when no complete
/// event is buffered yet.
fn next_sse_event(buffer: &mut String) -> Option<String> {
    let (pos, separator_len) = ["\r\n\r\n", "\n\n"]
        .into_iter()
        .filter_map(|separator| buffer.find(separator).map(|pos| (pos, separator.len())))
        .min_by_key(|&(pos, _)| pos)?;
    let event_text = buffer[..pos].to_string();
    // Remove in place rather than reallocating the remainder.
    buffer.drain(..pos + separator_len);
    Some(event_text)
}
659
/// Parse one SSE event block into `(event_type, data)`.
///
/// Follows the SSE field rules:
/// - Lines starting with `:` are comments and are ignored.
/// - A field's value is everything after the first `:`, minus one optional
///   leading space. A line without a colon is a field with an empty value.
/// - Multiple `data` lines are joined with `\n`.
/// - `event` sets the type (empty string if absent).
/// - Other fields (`id`, `retry`, unknown names) are ignored.
fn parse_sse_event(event_text: &str) -> (String, String) {
    let mut event_type = String::new();
    let mut data_lines: Vec<&str> = Vec::new();
    for line in event_text.lines() {
        if line.starts_with(':') {
            continue;
        }
        let (field, value) = match line.split_once(':') {
            Some((field, value)) => (field, value.strip_prefix(' ').unwrap_or(value)),
            None => (line, ""),
        };
        match field {
            "event" => event_type = value.to_string(),
            "data" => data_lines.push(value),
            _ => {}
        }
    }
    (event_type, data_lines.join("\n"))
}
675
676/// Build a reqwest Client, optionally with a default `X-Api-Key` header.
677fn build_client(api_key: Option<&str>) -> Client {
678    let mut builder = Client::builder();
679    if let Some(key) = api_key {
680        let mut headers = reqwest::header::HeaderMap::new();
681        match reqwest::header::HeaderValue::from_str(key) {
682            Ok(val) => {
683                headers.insert("x-api-key", val);
684            }
685            Err(_) => {
686                eprintln!(
687                    "warning: MOLD_API_KEY contains characters invalid for an HTTP header; \
688                     authentication header will not be sent"
689                );
690            }
691        }
692        builder = builder.default_headers(headers);
693    }
694    builder.build().unwrap_or_else(|_| Client::new())
695}
696
/// Normalize a host string into a full URL.
///
/// Surrounding whitespace and trailing slashes are removed. Accepts:
/// - Bare hostname: `hal9000` → `http://hal9000:7680`
/// - Host with port: `hal9000:8080` → `http://hal9000:8080`
/// - Full URL: `http://hal9000:7680` → unchanged
/// - URL without port: `http://hal9000` → unchanged (uses scheme default 80/443)
/// - Bracketed IPv6: `[::1]` → `http://[::1]:7680`, `[::1]:8080` → `http://[::1]:8080`
/// - Bare IPv6: `::1` → `http://[::1]:7680`
pub fn normalize_host(input: &str) -> String {
    const DEFAULT_PORT: u16 = 7680;

    let trimmed = input.trim().trim_end_matches('/');
    if trimmed.contains("://") {
        return trimmed.to_string();
    }

    // Bracketed IPv6 literal: a port can only follow the closing bracket.
    if trimmed.starts_with('[') {
        return if trimmed.ends_with(']') {
            format!("http://{trimmed}:{DEFAULT_PORT}")
        } else {
            format!("http://{trimmed}")
        };
    }

    match trimmed.matches(':').count() {
        0 => format!("http://{trimmed}:{DEFAULT_PORT}"),
        1 => format!("http://{trimmed}"),
        // More than one colon without brackets can only be a bare IPv6 address,
        // which must be bracketed inside a URL.
        _ => format!("http://[{trimmed}]:{DEFAULT_PORT}"),
    }
}
714
/// Signals that the server does not have the requested model (a 404
/// response that carries a body).
///
/// The wrapped string is the server's message. Recognized by
/// [`MoldClient::is_model_not_found`].
#[derive(Debug)]
pub struct ModelNotFoundError(pub String);

impl std::fmt::Display for ModelNotFoundError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Displayed verbatim; formatter width/alignment flags are not applied.
        f.write_str(&self.0)
    }
}

impl std::error::Error for ModelNotFoundError {}
727
728#[cfg(test)]
729mod tests {
730    use super::*;
731    use crate::test_support::ENV_LOCK;
732
733    #[test]
734    fn test_new_trims_trailing_slash() {
735        let client = MoldClient::new("http://localhost:7680/");
736        assert_eq!(client.host(), "http://localhost:7680");
737    }
738
739    #[test]
740    fn test_new_no_slash_unchanged() {
741        let client = MoldClient::new("http://localhost:7680");
742        assert_eq!(client.host(), "http://localhost:7680");
743    }
744
745    #[test]
746    fn test_new_multiple_slashes() {
747        let client = MoldClient::new("http://localhost:7680///");
748        assert_eq!(client.host(), "http://localhost:7680");
749    }
750
751    #[test]
752    fn test_from_env_mold_host() {
753        let _lock = ENV_LOCK.lock().unwrap_or_else(|e| e.into_inner());
754        // Single test to avoid env var races between parallel tests
755        unsafe { std::env::remove_var("MOLD_HOST") };
756        let client = MoldClient::from_env();
757        assert_eq!(client.host(), "http://localhost:7680");
758
759        let unique_url = "http://test-host-env:9999";
760        unsafe { std::env::set_var("MOLD_HOST", unique_url) };
761        let client = MoldClient::from_env();
762        assert_eq!(client.host(), unique_url);
763        unsafe { std::env::remove_var("MOLD_HOST") };
764    }
765
766    #[test]
767    fn test_is_connection_error_non_connect() {
768        // A generic anyhow error is not a connection error
769        let err = anyhow::anyhow!("something went wrong");
770        assert!(!MoldClient::is_connection_error(&err));
771    }
772
773    #[test]
774    fn test_is_model_not_found_via_custom_error() {
775        let err: anyhow::Error =
776            ModelNotFoundError("model 'test' is not downloaded".to_string()).into();
777        assert!(MoldClient::is_model_not_found(&err));
778    }
779
780    #[test]
781    fn test_is_model_not_found_generic_error() {
782        let err = anyhow::anyhow!("something else");
783        assert!(!MoldClient::is_model_not_found(&err));
784    }
785
786    #[test]
787    fn test_normalize_bare_hostname() {
788        let client = MoldClient::new("hal9000");
789        assert_eq!(client.host(), "http://hal9000:7680");
790    }
791
792    #[test]
793    fn test_normalize_hostname_with_port() {
794        let client = MoldClient::new("hal9000:8080");
795        assert_eq!(client.host(), "http://hal9000:8080");
796    }
797
798    #[test]
799    fn test_normalize_full_url_unchanged() {
800        let client = MoldClient::new("http://hal9000:7680");
801        assert_eq!(client.host(), "http://hal9000:7680");
802    }
803
804    #[test]
805    fn test_normalize_https_no_port() {
806        let client = MoldClient::new("https://hal9000");
807        assert_eq!(client.host(), "https://hal9000");
808    }
809
810    #[test]
811    fn test_normalize_http_no_port() {
812        let client = MoldClient::new("http://hal9000");
813        assert_eq!(client.host(), "http://hal9000");
814    }
815
816    #[test]
817    fn test_normalize_localhost() {
818        let client = MoldClient::new("localhost");
819        assert_eq!(client.host(), "http://localhost:7680");
820    }
821
822    #[test]
823    fn test_normalize_whitespace_trimmed() {
824        let client = MoldClient::new("  hal9000  ");
825        assert_eq!(client.host(), "http://hal9000:7680");
826    }
827
828    #[test]
829    fn test_normalize_ip_address() {
830        let client = MoldClient::new("192.168.1.100");
831        assert_eq!(client.host(), "http://192.168.1.100:7680");
832    }
833
834    #[test]
835    fn test_normalize_ip_with_port() {
836        let client = MoldClient::new("192.168.1.100:9090");
837        assert_eq!(client.host(), "http://192.168.1.100:9090");
838    }
839
840    #[test]
841    fn test_is_model_not_found_via_mold_error() {
842        let err: anyhow::Error =
843            MoldError::ModelNotFound("model 'test' is not downloaded".to_string()).into();
844        assert!(MoldClient::is_model_not_found(&err));
845    }
846
847    #[test]
848    fn test_is_connection_error_via_mold_error() {
849        let err: anyhow::Error = MoldError::Client("connection refused".to_string()).into();
850        assert!(MoldClient::is_connection_error(&err));
851    }
852
853    #[test]
854    fn parse_sse_event_joins_multiline_data() {
855        let (event_type, data) =
856            parse_sse_event("event: progress\ndata: {\"a\":1}\ndata: {\"b\":2}");
857        assert_eq!(event_type, "progress");
858        assert_eq!(data, "{\"a\":1}\n{\"b\":2}");
859    }
860
861    #[test]
862    fn next_sse_event_supports_crlf_delimiters() {
863        let mut buffer = "event: progress\r\ndata: {\"ok\":true}\r\n\r\nrest".to_string();
864        let event = next_sse_event(&mut buffer).expect("expected one event");
865        assert!(event.contains("event: progress"));
866        assert_eq!(buffer, "rest");
867    }
868
869    // ── Video header parsing tests ───────────────────────────────────────
870
871    #[test]
872    fn parse_video_headers_returns_none_without_frames() {
873        let headers = reqwest::header::HeaderMap::new();
874        assert!(parse_video_headers(&headers).is_none());
875    }
876
877    #[test]
878    fn parse_video_headers_returns_some_with_frames() {
879        let mut headers = reqwest::header::HeaderMap::new();
880        headers.insert("x-mold-video-frames", "33".parse().unwrap());
881        headers.insert("x-mold-video-fps", "12".parse().unwrap());
882        headers.insert("x-mold-video-width", "832".parse().unwrap());
883        headers.insert("x-mold-video-height", "480".parse().unwrap());
884
885        let meta = parse_video_headers(&headers).expect("should detect video");
886        assert_eq!(meta.frames, 33);
887        assert_eq!(meta.fps, 12);
888        assert_eq!(meta.width, Some(832));
889        assert_eq!(meta.height, Some(480));
890        assert!(!meta.has_audio);
891        assert!(meta.duration_ms.is_none());
892    }
893
894    #[test]
895    fn parse_video_headers_with_audio_metadata() {
896        let mut headers = reqwest::header::HeaderMap::new();
897        headers.insert("x-mold-video-frames", "17".parse().unwrap());
898        headers.insert("x-mold-video-fps", "24".parse().unwrap());
899        headers.insert("x-mold-video-has-audio", "1".parse().unwrap());
900        headers.insert("x-mold-video-duration-ms", "2750".parse().unwrap());
901        headers.insert("x-mold-video-audio-sample-rate", "44100".parse().unwrap());
902        headers.insert("x-mold-video-audio-channels", "2".parse().unwrap());
903
904        let meta = parse_video_headers(&headers).expect("should detect video");
905        assert_eq!(meta.frames, 17);
906        assert_eq!(meta.fps, 24);
907        assert!(meta.has_audio);
908        assert_eq!(meta.duration_ms, Some(2750));
909        assert_eq!(meta.audio_sample_rate, Some(44100));
910        assert_eq!(meta.audio_channels, Some(2));
911    }
912
913    #[test]
914    fn parse_video_headers_fps_defaults_to_24() {
915        let mut headers = reqwest::header::HeaderMap::new();
916        headers.insert("x-mold-video-frames", "10".parse().unwrap());
917        // No fps header — should default to 24
918
919        let meta = parse_video_headers(&headers).expect("should detect video");
920        assert_eq!(meta.fps, 24);
921    }
922
923    #[test]
924    fn parse_video_headers_has_audio_absent_is_false() {
925        let mut headers = reqwest::header::HeaderMap::new();
926        headers.insert("x-mold-video-frames", "10".parse().unwrap());
927        // No has-audio header
928
929        let meta = parse_video_headers(&headers).expect("should detect video");
930        assert!(!meta.has_audio);
931    }
932}