1use crate::error::MoldError;
2use crate::types::{
3 ExpandRequest, ExpandResponse, GalleryImage, GenerateRequest, GenerateResponse, ImageData,
4 ModelInfo, ModelInfoExtended, ServerStatus, SseCompleteEvent, SseErrorEvent, SseProgressEvent,
5 VideoData,
6};
7use anyhow::Result;
8use base64::Engine as _;
9use reqwest::Client;
10
11pub struct MoldClient {
12 base_url: String,
13 client: Client,
14}
15
16impl MoldClient {
17 pub fn new(base_url: &str) -> Self {
18 let client = build_client(None);
19 Self {
20 base_url: normalize_host(base_url),
21 client,
22 }
23 }
24
25 pub fn with_api_key(base_url: &str, api_key: String) -> Self {
27 let client = build_client(Some(&api_key));
28 Self {
29 base_url: normalize_host(base_url),
30 client,
31 }
32 }
33
34 pub fn from_env() -> Self {
35 let base_url =
36 std::env::var("MOLD_HOST").unwrap_or_else(|_| "http://localhost:7680".to_string());
37 let api_key = std::env::var("MOLD_API_KEY").ok().filter(|k| !k.is_empty());
38 let client = build_client(api_key.as_deref());
39 Self {
40 base_url: normalize_host(&base_url),
41 client,
42 }
43 }
44
45 pub async fn generate_raw(&self, req: &GenerateRequest) -> Result<Vec<u8>> {
49 let bytes = self
50 .client
51 .post(format!("{}/api/generate", self.base_url))
52 .json(req)
53 .send()
54 .await?
55 .error_for_status()?
56 .bytes()
57 .await?
58 .to_vec();
59 Ok(bytes)
60 }
61
62 pub async fn generate(&self, req: GenerateRequest) -> Result<GenerateResponse> {
67 let fallback_seed = req.seed.unwrap_or(0);
68 let width = req.width;
69 let height = req.height;
70 let model = req.model.clone();
71 let format = req.output_format;
72
73 let start = std::time::Instant::now();
74 let resp = self
75 .client
76 .post(format!("{}/api/generate", self.base_url))
77 .json(&req)
78 .send()
79 .await?
80 .error_for_status()?;
81
82 let seed_used = resp
85 .headers()
86 .get("x-mold-seed-used")
87 .and_then(|v| v.to_str().ok())
88 .and_then(|s| s.parse::<u64>().ok())
89 .unwrap_or(fallback_seed);
90
91 let video_meta = parse_video_headers(resp.headers());
93
94 let data = resp.bytes().await?.to_vec();
95 let generation_time_ms = start.elapsed().as_millis() as u64;
96
97 let video = video_meta.map(|meta| VideoData {
98 data: data.clone(),
99 format,
100 width: meta.width.unwrap_or(width),
101 height: meta.height.unwrap_or(height),
102 frames: meta.frames,
103 fps: meta.fps,
104 thumbnail: Vec::new(),
105 gif_preview: Vec::new(),
106 has_audio: meta.has_audio,
107 duration_ms: meta.duration_ms,
108 audio_sample_rate: meta.audio_sample_rate,
109 audio_channels: meta.audio_channels,
110 });
111
112 let images = if video.is_some() {
114 Vec::new()
115 } else {
116 vec![ImageData {
117 data,
118 format,
119 width,
120 height,
121 index: 0,
122 }]
123 };
124
125 Ok(GenerateResponse {
126 images,
127 generation_time_ms,
128 model,
129 seed_used,
130 video,
131 })
132 }
133
134 pub async fn list_models(&self) -> Result<Vec<ModelInfo>> {
135 let models = self.list_models_extended().await?;
136 Ok(models.into_iter().map(|m| m.info).collect())
137 }
138
139 pub async fn list_models_extended(&self) -> Result<Vec<ModelInfoExtended>> {
140 let resp = self
141 .client
142 .get(format!("{}/api/models", self.base_url))
143 .send()
144 .await?
145 .error_for_status()?
146 .json::<Vec<ModelInfoExtended>>()
147 .await?;
148 Ok(resp)
149 }
150
151 pub fn is_connection_error(err: &anyhow::Error) -> bool {
154 if let Some(mold_err) = err.downcast_ref::<MoldError>() {
156 if matches!(mold_err, MoldError::Client(_)) {
157 return true;
158 }
159 }
160 if let Some(reqwest_err) = err.downcast_ref::<reqwest::Error>() {
161 return reqwest_err.is_connect();
162 }
163 false
164 }
165
166 pub fn is_model_not_found(err: &anyhow::Error) -> bool {
169 if let Some(mold_err) = err.downcast_ref::<MoldError>() {
171 if matches!(mold_err, MoldError::ModelNotFound(_)) {
172 return true;
173 }
174 }
175 if let Some(reqwest_err) = err.downcast_ref::<reqwest::Error>() {
176 return reqwest_err.status() == Some(reqwest::StatusCode::NOT_FOUND);
177 }
178 err.downcast_ref::<ModelNotFoundError>().is_some()
180 }
181
182 pub async fn generate_stream(
189 &self,
190 req: &GenerateRequest,
191 progress_tx: tokio::sync::mpsc::UnboundedSender<SseProgressEvent>,
192 ) -> Result<Option<GenerateResponse>> {
193 let mut resp = self
194 .client
195 .post(format!("{}/api/generate/stream", self.base_url))
196 .json(req)
197 .send()
198 .await?;
199
200 if resp.status() == reqwest::StatusCode::NOT_FOUND {
201 let body = resp.text().await.unwrap_or_default();
202 if body.is_empty() {
203 return Ok(None);
205 }
206 return Err(MoldError::ModelNotFound(body).into());
208 }
209
210 if resp.status() == reqwest::StatusCode::UNPROCESSABLE_ENTITY {
211 let body = resp.text().await.unwrap_or_default();
212 return Err(MoldError::Validation(format!("validation error: {body}")).into());
213 }
214
215 if resp.status().is_client_error() || resp.status().is_server_error() {
216 let status = resp.status();
217 let body = resp.text().await.unwrap_or_default();
218 anyhow::bail!("server error {status}: {body}");
219 }
220
221 let mut buffer = String::new();
223 while let Some(chunk) = resp.chunk().await? {
224 buffer.push_str(&String::from_utf8_lossy(&chunk));
225
226 while let Some(event_text) = next_sse_event(&mut buffer) {
227 let (event_type, data) = parse_sse_event(&event_text);
228 match event_type.as_str() {
229 "progress" => {
230 if let Ok(p) = serde_json::from_str::<SseProgressEvent>(&data) {
231 let _ = progress_tx.send(p);
232 }
233 }
234 "complete" => {
235 let complete: SseCompleteEvent = serde_json::from_str(&data)?;
236 let payload =
237 base64::engine::general_purpose::STANDARD.decode(&complete.image)?;
238 let b64 = base64::engine::general_purpose::STANDARD;
239 let model = if complete.model.is_empty() {
243 req.model.clone()
244 } else {
245 complete.model
246 };
247
248 let (images, video) = if let (Some(frames), Some(fps)) =
250 (complete.video_frames, complete.video_fps)
251 {
252 let thumbnail = complete
253 .video_thumbnail
254 .as_deref()
255 .and_then(|s| b64.decode(s).ok())
256 .unwrap_or_default();
257 let gif_preview = complete
258 .video_gif_preview
259 .as_deref()
260 .and_then(|s| b64.decode(s).ok())
261 .unwrap_or_default();
262 let vd = VideoData {
263 data: payload,
264 format: complete.format,
265 width: complete.width,
266 height: complete.height,
267 frames,
268 fps,
269 thumbnail,
270 gif_preview,
271 has_audio: complete.video_has_audio,
272 duration_ms: complete.video_duration_ms,
273 audio_sample_rate: complete.video_audio_sample_rate,
274 audio_channels: complete.video_audio_channels,
275 };
276 (Vec::new(), Some(vd))
277 } else {
278 let img = ImageData {
279 data: payload,
280 format: complete.format,
281 width: complete.width,
282 height: complete.height,
283 index: 0,
284 };
285 (vec![img], None)
286 };
287
288 return Ok(Some(GenerateResponse {
289 images,
290 generation_time_ms: complete.generation_time_ms,
291 model,
292 seed_used: complete.seed_used,
293 video,
294 }));
295 }
296 "error" => {
297 let error: SseErrorEvent = serde_json::from_str(&data)?;
298 anyhow::bail!("server error: {}", error.message);
299 }
300 _ => {}
301 }
302 }
303 }
304
305 anyhow::bail!("SSE stream ended without complete event")
306 }
307
308 pub async fn pull_model(&self, model: &str) -> Result<String> {
312 let resp = self
313 .client
314 .post(format!("{}/api/models/pull", self.base_url))
315 .json(&serde_json::json!({ "model": model }))
316 .send()
317 .await?
318 .error_for_status()?
319 .text()
320 .await?;
321 Ok(resp)
322 }
323
324 pub async fn shutdown_server(&self) -> Result<()> {
326 self.client
327 .post(format!("{}/api/shutdown", self.base_url))
328 .send()
329 .await?
330 .error_for_status()?;
331 Ok(())
332 }
333
334 pub async fn pull_model_stream(
339 &self,
340 model: &str,
341 progress_tx: tokio::sync::mpsc::UnboundedSender<SseProgressEvent>,
342 ) -> Result<()> {
343 let mut resp = self
344 .client
345 .post(format!("{}/api/models/pull", self.base_url))
346 .header("Accept", "text/event-stream")
347 .json(&serde_json::json!({ "model": model }))
348 .send()
349 .await?;
350
351 if resp.status().is_client_error() || resp.status().is_server_error() {
352 let status = resp.status();
353 let body = resp.text().await.unwrap_or_default();
354 anyhow::bail!("server error {status}: {body}");
355 }
356
357 let content_type = resp
359 .headers()
360 .get("content-type")
361 .and_then(|v| v.to_str().ok())
362 .unwrap_or("");
363
364 if !content_type.contains("text/event-stream") {
365 drop(progress_tx);
368 let _ = resp.text().await?;
369 return Ok(());
370 }
371
372 let mut buffer = String::new();
374 while let Some(chunk) = resp.chunk().await? {
375 buffer.push_str(&String::from_utf8_lossy(&chunk));
376
377 while let Some(event_text) = next_sse_event(&mut buffer) {
378 let (event_type, data) = parse_sse_event(&event_text);
379 match event_type.as_str() {
380 "progress" => {
381 if let Ok(p) = serde_json::from_str::<SseProgressEvent>(&data) {
382 let is_done = matches!(p, SseProgressEvent::PullComplete { .. });
384 let _ = progress_tx.send(p);
385 if is_done {
386 return Ok(());
387 }
388 }
389 }
390 "error" => {
391 let error: SseErrorEvent = serde_json::from_str(&data)?;
392 anyhow::bail!("server error: {}", error.message);
393 }
394 _ => {}
395 }
396 }
397 }
398
399 Ok(())
400 }
401
402 pub fn host(&self) -> &str {
403 &self.base_url
404 }
405
406 pub async fn unload_model(&self) -> Result<String> {
407 let resp = self
408 .client
409 .delete(format!("{}/api/models/unload", self.base_url))
410 .send()
411 .await?
412 .error_for_status()?
413 .text()
414 .await?;
415 Ok(resp)
416 }
417
418 pub async fn server_status(&self) -> Result<ServerStatus> {
419 let resp = self
420 .client
421 .get(format!("{}/api/status", self.base_url))
422 .send()
423 .await?
424 .error_for_status()?
425 .json::<ServerStatus>()
426 .await?;
427 Ok(resp)
428 }
429
430 pub async fn list_gallery(&self) -> Result<Vec<GalleryImage>> {
432 let resp = self
433 .client
434 .get(format!("{}/api/gallery", self.base_url))
435 .send()
436 .await?
437 .error_for_status()?
438 .json::<Vec<GalleryImage>>()
439 .await?;
440 Ok(resp)
441 }
442
443 pub async fn get_gallery_image(&self, filename: &str) -> Result<Vec<u8>> {
445 let resp = self
446 .client
447 .get(format!("{}/api/gallery/image/{filename}", self.base_url))
448 .send()
449 .await?
450 .error_for_status()?
451 .bytes()
452 .await?;
453 Ok(resp.to_vec())
454 }
455
456 pub async fn delete_gallery_image(&self, filename: &str) -> Result<()> {
458 self.client
459 .delete(format!("{}/api/gallery/image/{filename}", self.base_url))
460 .send()
461 .await?
462 .error_for_status()?;
463 Ok(())
464 }
465
466 pub async fn get_gallery_thumbnail(&self, filename: &str) -> Result<Vec<u8>> {
468 let resp = self
469 .client
470 .get(format!(
471 "{}/api/gallery/thumbnail/{filename}",
472 self.base_url
473 ))
474 .send()
475 .await?
476 .error_for_status()?
477 .bytes()
478 .await?;
479 Ok(resp.to_vec())
480 }
481
482 pub async fn expand_prompt(&self, req: &ExpandRequest) -> Result<ExpandResponse> {
484 let resp = self
485 .client
486 .post(format!("{}/api/expand", self.base_url))
487 .json(req)
488 .send()
489 .await?
490 .error_for_status()?
491 .json::<ExpandResponse>()
492 .await?;
493 Ok(resp)
494 }
495
496 pub async fn upscale(&self, req: &crate::UpscaleRequest) -> Result<crate::UpscaleResponse> {
498 let resp = self
499 .client
500 .post(format!("{}/api/upscale", self.base_url))
501 .json(req)
502 .send()
503 .await?
504 .error_for_status()?
505 .json::<crate::UpscaleResponse>()
506 .await?;
507 Ok(resp)
508 }
509
510 pub async fn upscale_stream(
513 &self,
514 req: &crate::UpscaleRequest,
515 progress_tx: tokio::sync::mpsc::UnboundedSender<SseProgressEvent>,
516 ) -> Result<Option<crate::UpscaleResponse>> {
517 let mut resp = self
518 .client
519 .post(format!("{}/api/upscale/stream", self.base_url))
520 .json(req)
521 .send()
522 .await?;
523
524 if resp.status() == reqwest::StatusCode::NOT_FOUND {
525 let body = resp.text().await.unwrap_or_default();
526 if body.is_empty() {
527 return Ok(None); }
529 return Err(MoldError::ModelNotFound(body).into());
530 }
531
532 if resp.status() == reqwest::StatusCode::UNPROCESSABLE_ENTITY {
533 let body = resp.text().await.unwrap_or_default();
534 return Err(MoldError::Validation(format!("validation error: {body}")).into());
535 }
536
537 if resp.status().is_client_error() || resp.status().is_server_error() {
538 let status = resp.status();
539 let body = resp.text().await.unwrap_or_default();
540 anyhow::bail!("server error {status}: {body}");
541 }
542
543 let mut buffer = String::new();
544 while let Some(chunk) = resp.chunk().await? {
545 buffer.push_str(&String::from_utf8_lossy(&chunk));
546
547 while let Some(event_text) = next_sse_event(&mut buffer) {
548 let (event_type, data) = parse_sse_event(&event_text);
549 match event_type.as_str() {
550 "progress" => {
551 if let Ok(p) = serde_json::from_str::<SseProgressEvent>(&data) {
552 let _ = progress_tx.send(p);
553 }
554 }
555 "complete" => {
556 let complete: crate::SseUpscaleCompleteEvent = serde_json::from_str(&data)?;
557 let image_data =
558 base64::engine::general_purpose::STANDARD.decode(&complete.image)?;
559 return Ok(Some(crate::UpscaleResponse {
560 image: crate::ImageData {
561 data: image_data,
562 format: complete.format,
563 width: complete.original_width * complete.scale_factor,
564 height: complete.original_height * complete.scale_factor,
565 index: 0,
566 },
567 upscale_time_ms: complete.upscale_time_ms,
568 model: complete.model,
569 scale_factor: complete.scale_factor,
570 original_width: complete.original_width,
571 original_height: complete.original_height,
572 }));
573 }
574 "error" => {
575 let error: crate::SseErrorEvent = serde_json::from_str(&data)?;
576 anyhow::bail!("server error: {}", error.message);
577 }
578 _ => {}
579 }
580 }
581 }
582
583 anyhow::bail!("SSE stream ended without complete event")
584 }
585}
586
/// Video metadata read from `x-mold-video-*` response headers.
///
/// Produced by `parse_video_headers`; all fields are plain values, so the
/// struct is `Copy` and can be compared wholesale in tests.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
struct VideoMeta {
    /// Number of frames (`x-mold-video-frames`).
    frames: u32,
    /// Frames per second (`x-mold-video-fps`).
    fps: u32,
    /// Frame width, if the header was present and numeric.
    width: Option<u32>,
    /// Frame height, if the header was present and numeric.
    height: Option<u32>,
    /// Whether the video carries an audio track.
    has_audio: bool,
    /// Duration in milliseconds, if reported.
    duration_ms: Option<u64>,
    /// Audio sample rate, if reported.
    audio_sample_rate: Option<u32>,
    /// Number of audio channels, if reported.
    audio_channels: Option<u32>,
}
598
599fn parse_video_headers(headers: &reqwest::header::HeaderMap) -> Option<VideoMeta> {
602 let frames = headers
603 .get("x-mold-video-frames")
604 .and_then(|v| v.to_str().ok())
605 .and_then(|s| s.parse::<u32>().ok())?;
606 let fps = headers
607 .get("x-mold-video-fps")
608 .and_then(|v| v.to_str().ok())
609 .and_then(|s| s.parse::<u32>().ok())
610 .unwrap_or(24);
611 let width = headers
612 .get("x-mold-video-width")
613 .and_then(|v| v.to_str().ok())
614 .and_then(|s| s.parse::<u32>().ok());
615 let height = headers
616 .get("x-mold-video-height")
617 .and_then(|v| v.to_str().ok())
618 .and_then(|s| s.parse::<u32>().ok());
619 let has_audio = headers
620 .get("x-mold-video-has-audio")
621 .and_then(|v| v.to_str().ok())
622 .map(|s| s == "1")
623 .unwrap_or(false);
624 let duration_ms = headers
625 .get("x-mold-video-duration-ms")
626 .and_then(|v| v.to_str().ok())
627 .and_then(|s| s.parse::<u64>().ok());
628 let audio_sample_rate = headers
629 .get("x-mold-video-audio-sample-rate")
630 .and_then(|v| v.to_str().ok())
631 .and_then(|s| s.parse::<u32>().ok());
632 let audio_channels = headers
633 .get("x-mold-video-audio-channels")
634 .and_then(|v| v.to_str().ok())
635 .and_then(|s| s.parse::<u32>().ok());
636
637 Some(VideoMeta {
638 frames,
639 fps,
640 width,
641 height,
642 has_audio,
643 duration_ms,
644 audio_sample_rate,
645 audio_channels,
646 })
647}
648
/// Removes the first complete SSE event from the front of `buffer` and
/// returns its text, without the terminating blank line.
///
/// An event ends at a blank line, written as `\n\n` or `\r\n\r\n`. The
/// earliest terminator wins, so an LF-terminated event is never merged with
/// a following CRLF-terminated one. Returns `None`, leaving `buffer`
/// untouched, while no complete event has arrived yet.
fn next_sse_event(buffer: &mut String) -> Option<String> {
    const TERMINATORS: [&str; 2] = ["\r\n\r\n", "\n\n"];

    let (pos, terminator_len) = TERMINATORS
        .into_iter()
        .filter_map(|sep| buffer.find(sep).map(|found| (found, sep.len())))
        .min_by_key(|&(found, _)| found)?;

    let event_text = buffer[..pos].to_string();
    // Remove the event and its terminator in place rather than reallocating
    // the (possibly large) remainder.
    buffer.drain(..pos + terminator_len);
    Some(event_text)
}
659
/// Splits one SSE event into its `event:` type and `data:` payload.
///
/// Follows the SSE field rules:
/// - lines starting with `:` are comments;
/// - a single space after the colon is dropped, and any other whitespace is kept;
/// - a line without a colon is a field with an empty value;
/// - multiple `data` lines are joined with `\n`;
/// - the last `event` line wins, and other fields (`id`, `retry`) are ignored.
///
/// Without an `event` line the returned type is empty.
fn parse_sse_event(event_text: &str) -> (String, String) {
    let mut event_type = String::new();
    let mut data_lines: Vec<&str> = Vec::new();

    for line in event_text.lines() {
        // `str::lines` does not treat a lone `\r` as a line ending; drop any
        // that is left at the end of a line.
        let line = line.strip_suffix('\r').unwrap_or(line);
        if line.starts_with(':') {
            continue;
        }
        let (field, value) = line.split_once(':').unwrap_or((line, ""));
        let value = value.strip_prefix(' ').unwrap_or(value);
        match field {
            "event" => event_type = value.to_string(),
            "data" => data_lines.push(value),
            _ => {}
        }
    }

    (event_type, data_lines.join("\n"))
}
675
676fn build_client(api_key: Option<&str>) -> Client {
678 let mut builder = Client::builder();
679 if let Some(key) = api_key {
680 let mut headers = reqwest::header::HeaderMap::new();
681 match reqwest::header::HeaderValue::from_str(key) {
682 Ok(val) => {
683 headers.insert("x-api-key", val);
684 }
685 Err(_) => {
686 eprintln!(
687 "warning: MOLD_API_KEY contains characters invalid for an HTTP header; \
688 authentication header will not be sent"
689 );
690 }
691 }
692 builder = builder.default_headers(headers);
693 }
694 builder.build().unwrap_or_else(|_| Client::new())
695}
696
/// Normalizes a user-supplied server address into a base URL.
///
/// Surrounding whitespace and trailing slashes are removed. Input that
/// already has a scheme (`scheme://…`) is otherwise kept as-is. Anything else
/// gets `http://` and, when no port is given, the default port 7680.
///
/// IPv6 literals are supported. A bare address such as `::1` is bracketed
/// (`http://[::1]:7680`). A bracketed one keeps its port (`[::1]:9000`) or
/// receives the default.
pub fn normalize_host(input: &str) -> String {
    const DEFAULT_PORT: u16 = 7680;

    let host = input.trim().trim_end_matches('/');
    if host.contains("://") {
        return host.to_string();
    }

    let has_port = match host.strip_prefix('[') {
        // Bracketed IPv6 literal: a port can only follow the closing bracket.
        Some(bracketed) => bracketed.contains("]:"),
        None => match host.matches(':').count() {
            0 => false,
            1 => true,
            // Several colons without brackets can only be a bare IPv6
            // address, which must be bracketed before a port is appended.
            _ => return format!("http://[{host}]:{DEFAULT_PORT}"),
        },
    };

    if has_port {
        format!("http://{host}")
    } else {
        format!("http://{host}:{DEFAULT_PORT}")
    }
}
714
/// Error carrying a message that a requested model could not be found.
///
/// Its `Display` output is exactly the wrapped message, and
/// [`MoldClient::is_model_not_found`] recognizes it.
#[derive(Debug)]
pub struct ModelNotFoundError(pub String);

impl std::error::Error for ModelNotFoundError {}

impl std::fmt::Display for ModelNotFoundError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self(message) = self;
        f.write_str(message)
    }
}
727
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_support::ENV_LOCK;

    /// Snapshot of one environment variable, restored when dropped so that a
    /// failing assertion cannot leak a modified value into other tests.
    struct EnvVarGuard {
        key: &'static str,
        previous: Option<std::ffi::OsString>,
    }

    impl EnvVarGuard {
        fn capture(key: &'static str) -> Self {
            Self {
                key,
                previous: std::env::var_os(key),
            }
        }
    }

    impl Drop for EnvVarGuard {
        fn drop(&mut self) {
            // SAFETY: guards are only created while ENV_LOCK is held and are
            // dropped before it is released, which serializes environment
            // mutation among tests that take the lock (assumes every
            // env-mutating test in the crate does).
            unsafe {
                match &self.previous {
                    Some(value) => std::env::set_var(self.key, value),
                    None => std::env::remove_var(self.key),
                }
            }
        }
    }

    #[test]
    fn test_new_trims_trailing_slash() {
        let client = MoldClient::new("http://localhost:7680/");
        assert_eq!(client.host(), "http://localhost:7680");
    }

    #[test]
    fn test_new_no_slash_unchanged() {
        let client = MoldClient::new("http://localhost:7680");
        assert_eq!(client.host(), "http://localhost:7680");
    }

    #[test]
    fn test_new_multiple_slashes() {
        let client = MoldClient::new("http://localhost:7680///");
        assert_eq!(client.host(), "http://localhost:7680");
    }

    #[test]
    fn test_with_api_key_normalizes_host() {
        let client = MoldClient::with_api_key("hal9000", "secret".to_string());
        assert_eq!(client.host(), "http://hal9000:7680");
    }

    #[test]
    fn test_from_env_mold_host() {
        let _lock = ENV_LOCK.lock().unwrap_or_else(|e| e.into_inner());
        // Declared after `_lock`, so it is dropped (restoring MOLD_HOST)
        // before the lock is released.
        let _restore = EnvVarGuard::capture("MOLD_HOST");

        // SAFETY: ENV_LOCK is held for the whole test; see EnvVarGuard.
        unsafe { std::env::remove_var("MOLD_HOST") };
        let client = MoldClient::from_env();
        assert_eq!(client.host(), "http://localhost:7680");

        let unique_url = "http://test-host-env:9999";
        // SAFETY: ENV_LOCK is held for the whole test; see EnvVarGuard.
        unsafe { std::env::set_var("MOLD_HOST", unique_url) };
        let client = MoldClient::from_env();
        assert_eq!(client.host(), unique_url);
    }

    #[test]
    fn test_is_connection_error_non_connect() {
        let err = anyhow::anyhow!("something went wrong");
        assert!(!MoldClient::is_connection_error(&err));
    }

    #[test]
    fn test_is_model_not_found_via_custom_error() {
        let err: anyhow::Error =
            ModelNotFoundError("model 'test' is not downloaded".to_string()).into();
        assert!(MoldClient::is_model_not_found(&err));
    }

    #[test]
    fn test_is_model_not_found_generic_error() {
        let err = anyhow::anyhow!("something else");
        assert!(!MoldClient::is_model_not_found(&err));
    }

    #[test]
    fn test_normalize_bare_hostname() {
        let client = MoldClient::new("hal9000");
        assert_eq!(client.host(), "http://hal9000:7680");
    }

    #[test]
    fn test_normalize_hostname_with_port() {
        let client = MoldClient::new("hal9000:8080");
        assert_eq!(client.host(), "http://hal9000:8080");
    }

    #[test]
    fn test_normalize_full_url_unchanged() {
        let client = MoldClient::new("http://hal9000:7680");
        assert_eq!(client.host(), "http://hal9000:7680");
    }

    #[test]
    fn test_normalize_https_no_port() {
        let client = MoldClient::new("https://hal9000");
        assert_eq!(client.host(), "https://hal9000");
    }

    #[test]
    fn test_normalize_http_no_port() {
        let client = MoldClient::new("http://hal9000");
        assert_eq!(client.host(), "http://hal9000");
    }

    #[test]
    fn test_normalize_localhost() {
        let client = MoldClient::new("localhost");
        assert_eq!(client.host(), "http://localhost:7680");
    }

    #[test]
    fn test_normalize_whitespace_trimmed() {
        let client = MoldClient::new("  hal9000  ");
        assert_eq!(client.host(), "http://hal9000:7680");
    }

    #[test]
    fn test_normalize_ip_address() {
        let client = MoldClient::new("192.168.1.100");
        assert_eq!(client.host(), "http://192.168.1.100:7680");
    }

    #[test]
    fn test_normalize_ip_with_port() {
        let client = MoldClient::new("192.168.1.100:9090");
        assert_eq!(client.host(), "http://192.168.1.100:9090");
    }

    #[test]
    fn test_is_model_not_found_via_mold_error() {
        let err: anyhow::Error =
            MoldError::ModelNotFound("model 'test' is not downloaded".to_string()).into();
        assert!(MoldClient::is_model_not_found(&err));
    }

    #[test]
    fn test_is_connection_error_via_mold_error() {
        let err: anyhow::Error = MoldError::Client("connection refused".to_string()).into();
        assert!(MoldClient::is_connection_error(&err));
    }

    #[test]
    fn parse_sse_event_joins_multiline_data() {
        let (event_type, data) =
            parse_sse_event("event: progress\ndata: {\"a\":1}\ndata: {\"b\":2}");
        assert_eq!(event_type, "progress");
        assert_eq!(data, "{\"a\":1}\n{\"b\":2}");
    }

    #[test]
    fn parse_sse_event_skips_comments() {
        let (event_type, data) = parse_sse_event(": keep-alive\nevent: error\ndata: {}");
        assert_eq!(event_type, "error");
        assert_eq!(data, "{}");
    }

    #[test]
    fn next_sse_event_supports_crlf_delimiters() {
        let mut buffer = "event: progress\r\ndata: {\"ok\":true}\r\n\r\nrest".to_string();
        let event = next_sse_event(&mut buffer).expect("expected one event");
        assert!(event.contains("event: progress"));
        assert_eq!(buffer, "rest");
    }

    #[test]
    fn next_sse_event_supports_lf_delimiters() {
        let mut buffer = "event: progress\ndata: {}\n\nrest".to_string();
        let event = next_sse_event(&mut buffer).expect("expected one event");
        assert_eq!(event, "event: progress\ndata: {}");
        assert_eq!(buffer, "rest");
    }

    #[test]
    fn next_sse_event_waits_for_complete_event() {
        let mut buffer = "event: progress\ndata: {}\n".to_string();
        assert!(next_sse_event(&mut buffer).is_none());
        assert_eq!(buffer, "event: progress\ndata: {}\n");
    }

    #[test]
    fn parse_video_headers_returns_none_without_frames() {
        let headers = reqwest::header::HeaderMap::new();
        assert!(parse_video_headers(&headers).is_none());
    }

    #[test]
    fn parse_video_headers_returns_some_with_frames() {
        let mut headers = reqwest::header::HeaderMap::new();
        headers.insert("x-mold-video-frames", "33".parse().unwrap());
        headers.insert("x-mold-video-fps", "12".parse().unwrap());
        headers.insert("x-mold-video-width", "832".parse().unwrap());
        headers.insert("x-mold-video-height", "480".parse().unwrap());

        let meta = parse_video_headers(&headers).expect("should detect video");
        assert_eq!(meta.frames, 33);
        assert_eq!(meta.fps, 12);
        assert_eq!(meta.width, Some(832));
        assert_eq!(meta.height, Some(480));
        assert!(!meta.has_audio);
        assert!(meta.duration_ms.is_none());
    }

    #[test]
    fn parse_video_headers_with_audio_metadata() {
        let mut headers = reqwest::header::HeaderMap::new();
        headers.insert("x-mold-video-frames", "17".parse().unwrap());
        headers.insert("x-mold-video-fps", "24".parse().unwrap());
        headers.insert("x-mold-video-has-audio", "1".parse().unwrap());
        headers.insert("x-mold-video-duration-ms", "2750".parse().unwrap());
        headers.insert("x-mold-video-audio-sample-rate", "44100".parse().unwrap());
        headers.insert("x-mold-video-audio-channels", "2".parse().unwrap());

        let meta = parse_video_headers(&headers).expect("should detect video");
        assert_eq!(meta.frames, 17);
        assert_eq!(meta.fps, 24);
        assert!(meta.has_audio);
        assert_eq!(meta.duration_ms, Some(2750));
        assert_eq!(meta.audio_sample_rate, Some(44100));
        assert_eq!(meta.audio_channels, Some(2));
    }

    #[test]
    fn parse_video_headers_fps_defaults_to_24() {
        let mut headers = reqwest::header::HeaderMap::new();
        headers.insert("x-mold-video-frames", "10".parse().unwrap());
        let meta = parse_video_headers(&headers).expect("should detect video");
        assert_eq!(meta.fps, 24);
    }

    #[test]
    fn parse_video_headers_has_audio_absent_is_false() {
        let mut headers = reqwest::header::HeaderMap::new();
        headers.insert("x-mold-video-frames", "10".parse().unwrap());
        let meta = parse_video_headers(&headers).expect("should detect video");
        assert!(!meta.has_audio);
    }
}