Skip to main content

fx_mistral/
ocr.rs

1use crate::{MistralApiError, MistralClient, MistralError};
2
3use serde::{Deserialize, Serialize};
4
5pub struct OcrClient<'a> {
6    mistral_client: &'a MistralClient,
7    ocr_path: String,
8    model: String,
9}
10
// OCR request structs.
// NOTE: field names and declaration order below define the JSON request body.

/// JSON body POSTed to the OCR endpoint.
#[derive(Serialize, Deserialize, Debug)]
struct OcrRequest {
    /// Model identifier; `OcrClient` copies its configured model here.
    model: String,
    /// The document to run OCR on.
    document: Document,
    /// Request flag for base64 image data in the response
    /// (presumably surfaced via `Image::image_base64` — confirm against API docs).
    include_image_base64: bool,
}

/// Reference to the document to process.
#[derive(Serialize, Deserialize, Debug)]
struct Document {
    /// Serialized under the JSON key `"type"` (`type` is a Rust keyword);
    /// `OcrClient` always sets it to `"document_url"`.
    #[serde(rename = "type")]
    document_type: String,
    /// URL of the document; `OcrClient` passes the caller's signed URL here.
    document_url: String,
}
26
27// OCR Response structs
28
29#[derive(Serialize, Deserialize, Debug)]
30pub struct OcrResponse {
31    pub pages: Vec<Page>,
32    pub model: String,
33    pub usage_info: UsageInfo,
34}
35
36#[derive(Serialize, Deserialize, Debug)]
37pub struct Page {
38    pub index: u32,
39    pub markdown: String,
40    pub images: Vec<Image>,
41    pub dimensions: Dimensions,
42}
43
44#[derive(Serialize, Deserialize, Debug)]
45pub struct Image {
46    pub id: String,
47    pub top_left_x: u32,
48    pub top_left_y: u32,
49    pub bottom_right_x: u32,
50    pub bottom_right_y: u32,
51    pub image_base64: Option<String>,
52}
53
54#[derive(Serialize, Deserialize, Debug)]
55pub struct Dimensions {
56    pub dpi: u32,
57    pub height: u32,
58    pub width: u32,
59}
60
61#[derive(Serialize, Deserialize, Debug)]
62pub struct UsageInfo {
63    pub pages_processed: u32,
64    pub doc_size_bytes: u64,
65}
66
67impl<'a> OcrClient<'a> {
68    pub fn new(mistral_client: &'a MistralClient, model: &str) -> Self {
69        OcrClient {
70            mistral_client,
71            ocr_path: "ocr".to_string(),
72            model: model.to_string()
73        }
74    }
75
76
77    pub async fn get_ocr_results(&self, signed_url: &str) -> Result<OcrResponse, MistralError> {
78        let ocr_request = OcrRequest {
79            model: self.model.clone(),
80            document: Document {
81                document_type: "document_url".to_string(),
82                document_url: signed_url.to_string(),
83            },
84            include_image_base64: false,
85        };
86
87        let response = self
88            .mistral_client
89            .client
90            .post(&format!("{}/{}", self.mistral_client.base_url, self.ocr_path))
91            .bearer_auth(&self.mistral_client.api_key)
92            .json(&ocr_request)
93            .send()
94            .await
95            .map_err(MistralError::Network)?;
96
97        let status = response.status();
98        let response_text = response.text().await.map_err(MistralError::Network)?;
99
100        if !status.is_success() {
101            return match serde_json::from_str::<MistralApiError>(&response_text) {
102                Ok(mut err) => {
103                    err.description = crate::error_description(err.code).to_string();
104                    Err(MistralError::Api(err))
105                },
106                Err(_) => Err(MistralError::Http(status)),
107            };
108        }
109
110        serde_json::from_str::<OcrResponse>(&response_text)
111            .map_err(MistralError::Parse)
112    }
113}
114