use serde::{Deserialize, Serialize};
use tracing::{debug, error, info, instrument, warn};
use std::path::{Path, PathBuf};
use base64::{engine::general_purpose::STANDARD as Base64Standard, Engine};
use tokio::fs;
use super::error::OcrOutputError;
/// JSON request body for an OCR call.
///
/// Every `Option` field that is `None` is left out of the serialized object
/// entirely (rather than being sent as `null`).
#[derive(Clone, Debug, Serialize)]
#[serde(rename_all = "snake_case")]
pub struct OcrRequest {
    /// Model identifier, serialized as `model`.
    pub model: String,

    /// Source document, serialized as the tagged `document` object.
    pub document: DocumentInput,

    /// Optional identifier, serialized as `id`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub id: Option<String>,

    /// Optional page selection, serialized as `pages`.
    /// NOTE(review): index base (0- or 1-based) is defined by the API — confirm.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub pages: Option<Vec<u32>>,

    /// Optional flag, serialized as `include_image_base64`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub include_image_base64: Option<bool>,

    /// Optional value, serialized as `image_limit`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub image_limit: Option<u32>,

    /// Optional value, serialized as `image_min_size`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub image_min_size: Option<u32>,
}
/// Where the OCR service should fetch its input from.
///
/// Serialized as an internally tagged object whose `type` key carries the
/// snake_case variant name, for example
/// `{"type": "document_url", "document_url": "https://..."}`.
#[derive(Clone, Debug, Serialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum DocumentInput {
    /// A document referenced by URL.
    DocumentUrl {
        document_url: String,
    },
    /// An image referenced by URL.
    ImageUrl {
        image_url: String,
    },
}
/// Deserialized body of a successful OCR response.
#[derive(Clone, Debug, Deserialize)]
#[serde(rename_all = "snake_case")]
pub struct OcrResponse {
    /// Per-page results; not assumed to be sorted (see `save_to_files`).
    pub pages: Vec<PageDetail>,
    /// Model name echoed back by the service.
    pub model: String,
    /// Usage accounting for the request.
    pub usage_info: UsageInfo,
}
impl OcrResponse {
#[instrument(skip(self), fields(output_dir = %output_dir.as_ref().display()))]
pub async fn save_to_files(
&self,
output_dir: impl AsRef<Path>,
) -> Result<(), OcrOutputError> {
let output_dir = output_dir.as_ref();
if output_dir.is_file() {
return Err(OcrOutputError::InvalidOutputPath(output_dir.to_path_buf()));
}
fs::create_dir_all(&output_dir).await?;
debug!(target: "mistral_ocr::output", path = %output_dir.display(), "Ensured output directory exists");
let md_path = output_dir.join("output.md");
let mut combined_markdown = String::new();
let mut sorted_pages = self.pages.clone(); sorted_pages.sort_by_key(|p| p.index);
for (i, page) in sorted_pages.iter().enumerate() {
if i > 0 {
combined_markdown.push_str("\n\n---\n\n"); }
combined_markdown.push_str(&page.markdown);
}
fs::write(&md_path, combined_markdown).await?;
debug!(target: "mistral_ocr::output", path = %md_path.display(), "Wrote combined markdown");
let mut images_dir = None;
let mut image_count = 0;
for page in &sorted_pages {
for image in &page.images {
if !image.image_base64.is_empty() {
if images_dir.is_none() {
let images_dir_path = output_dir.join("images");
fs::create_dir_all(&images_dir_path).await?;
debug!(target: "mistral_ocr::output", path = %images_dir_path.display(), "Ensured images directory exists");
images_dir = Some(images_dir_path);
}
let base64_data_to_decode = if image.image_base64.starts_with("data:") {
if let Some(comma_index) = image.image_base64.find(',') {
let data_part = &image.image_base64[comma_index + 1..];
debug!(target: "mistral_ocr::output", image_id = %image.id, "Detected and stripped Data URI prefix");
data_part
} else {
warn!(target: "mistral_ocr::output", image_id = %image.id, "Found 'data:' prefix but no comma, attempting decode anyway");
&image.image_base64
}
} else {
&image.image_base64
};
let image_data = Base64Standard.decode(base64_data_to_decode).map_err(|e| {
error!(target: "mistral_ocr::output", image_id = %image.id, error = %e, "Base64 decoding failed");
OcrOutputError::Base64Decode{ image_id: image.id.clone(), source: e }
})?;
let image_path = images_dir.as_ref().unwrap().join(&image.id);
fs::write(&image_path, image_data).await?;
image_count += 1;
debug!(target: "mistral_ocr::output", path = %image_path.display(), "Wrote image");
}
}
}
info!(target: "mistral_ocr::output", markdown_path = %md_path.display(), images_saved = image_count, "OCR output saved successfully");
Ok(())
}
}
/// OCR output for a single page.
#[derive(Clone, Debug, Deserialize)]
#[serde(rename_all = "snake_case")]
pub struct PageDetail {
    /// Page index as reported by the service; used to order pages on save.
    pub index: u32,
    /// Recognized page content as markdown.
    pub markdown: String,
    /// Images extracted from this page.
    pub images: Vec<ImageDetail>,
    /// Page size information.
    pub dimensions: PageDimensions,
}
#[derive(Debug, Clone, Deserialize)]
#[serde(rename_all = "snake_case")]
pub struct ImageDetail {
pub id: String,
pub top_left_x: u32,
pub top_left_y: u32,
pub bottom_right_x: u32,
pub bottom_right_y: u32,
pub image_base64: String, }
/// Physical characteristics of a page as reported by the service.
#[derive(Clone, Debug, Deserialize)]
#[serde(rename_all = "snake_case")]
pub struct PageDimensions {
    /// Resolution in dots per inch.
    pub dpi: u32,
    /// Page height (units defined by the API; presumably pixels — confirm).
    pub height: u32,
    /// Page width (units defined by the API; presumably pixels — confirm).
    pub width: u32,
}
/// Usage accounting attached to an OCR response.
#[derive(Clone, Debug, Deserialize)]
#[serde(rename_all = "snake_case")]
pub struct UsageInfo {
    /// Number of pages the service processed.
    pub pages_processed: u32,
    /// Document size in bytes, when the service reports it.
    pub doc_size_bytes: Option<u64>,
}
/// Metadata returned after uploading a file.
#[derive(Clone, Debug, Deserialize)]
#[serde(rename_all = "snake_case")]
pub struct FileUploadResponse {
    /// Identifier assigned to the uploaded file.
    pub id: String,
    /// Object type string reported by the service.
    pub object: String,
    /// File size in bytes.
    pub bytes: u64,
    /// Creation timestamp (presumably Unix seconds — confirm).
    pub created_at: u64,
    /// Original file name.
    pub filename: String,
    /// Declared purpose of the upload.
    pub purpose: String,
    /// Optional sample type.
    pub sample_type: Option<String>,
    /// Optional line count.
    pub num_lines: Option<u64>,
    /// Source of the file as reported by the service.
    pub source: String,
}
/// One entry of a validation error body.
#[derive(Clone, Debug, Deserialize)]
#[serde(rename_all = "snake_case")]
pub struct ValidationErrorDetail {
    /// Location of the offending value; elements are kept as raw JSON.
    pub loc: Vec<serde_json::Value>,
    /// Human-readable error message.
    pub msg: String,
    /// Error category, read from the JSON key `type`.
    #[serde(rename = "type")]
    pub error_type: String,
}
/// Error body whose `detail` key holds a list of validation errors.
#[derive(Clone, Debug, Deserialize)]
pub struct HttpValidationErrorResponse {
    /// The individual validation failures.
    pub detail: Vec<ValidationErrorDetail>,
}
/// Response carrying a signed URL.
#[derive(Clone, Debug, Deserialize)]
pub struct SignedUrlResponse {
    /// The signed URL returned by the service.
    pub url: String,
}