Skip to main content

datalab_cli/commands/
extract.rs

1use crate::cache::Cache;
2use crate::client::{add_form_field, build_form_with_file, DatalabClient};
3use crate::error::{DatalabError, Result};
4use crate::output::Progress;
5use clap::Args;
6use reqwest::multipart::Form;
7use serde_json::json;
8use std::fs;
9use std::path::PathBuf;
10
11#[derive(Args, Debug)]
12pub struct ExtractArgs {
13    /// File path or URL to extract from
14    #[arg(value_name = "FILE|URL")]
15    pub input: String,
16
17    /// JSON schema file or string defining extraction fields
18    #[arg(long, value_name = "SCHEMA", help_heading = "Schema Options")]
19    pub schema: String,
20
21    /// Checkpoint ID to reuse parsed document
22    #[arg(long, value_name = "ID", help_heading = "Processing Options")]
23    pub checkpoint_id: Option<String>,
24
25    /// Processing mode: fast, balanced, accurate
26    #[arg(
27        long,
28        default_value = "fast",
29        value_name = "MODE",
30        help_heading = "Processing Options"
31    )]
32    pub mode: String,
33
34    /// Maximum pages to process
35    #[arg(long, value_name = "N", help_heading = "Processing Options")]
36    pub max_pages: Option<u32>,
37
38    /// Page range (e.g., "0-5,10")
39    #[arg(long, value_name = "RANGE", help_heading = "Processing Options")]
40    pub page_range: Option<String>,
41
42    /// Save checkpoint for reuse
43    #[arg(long, help_heading = "Processing Options")]
44    pub save_checkpoint: bool,
45
46    /// Include per-field confidence scores
47    #[arg(long, help_heading = "Output Options")]
48    pub include_scores: bool,
49
50    /// Skip local cache lookup
51    #[arg(long, help_heading = "Cache Options")]
52    pub skip_cache: bool,
53
54    /// Write result to file
55    #[arg(long, short, value_name = "FILE", help_heading = "Output Options")]
56    pub output: Option<PathBuf>,
57
58    /// Request timeout in seconds
59    #[arg(
60        long,
61        default_value = "300",
62        value_name = "SECS",
63        help_heading = "Advanced Options"
64    )]
65    pub timeout: u64,
66}
67
68impl ExtractArgs {
69    fn to_cache_params(&self) -> serde_json::Value {
70        json!({
71            "schema": self.schema,
72            "checkpoint_id": self.checkpoint_id,
73            "mode": self.mode,
74            "max_pages": self.max_pages,
75            "page_range": self.page_range,
76            "save_checkpoint": self.save_checkpoint,
77            "include_scores": self.include_scores,
78        })
79    }
80
81    fn get_schema(&self) -> Result<String> {
82        let schema_path = PathBuf::from(&self.schema);
83        if schema_path.exists() {
84            Ok(fs::read_to_string(&schema_path)?)
85        } else {
86            serde_json::from_str::<serde_json::Value>(&self.schema).map_err(|_| {
87                DatalabError::InvalidInput(
88                    "Schema must be valid JSON or a path to a JSON file".to_string(),
89                )
90            })?;
91            Ok(self.schema.clone())
92        }
93    }
94
95    fn add_to_form(&self, mut form: Form, schema: &str) -> Form {
96        form = add_form_field(form, "page_schema", schema);
97        form = add_form_field(form, "mode", &self.mode);
98
99        if let Some(ref checkpoint_id) = self.checkpoint_id {
100            form = add_form_field(form, "checkpoint_id", checkpoint_id);
101        }
102        if let Some(max_pages) = self.max_pages {
103            form = add_form_field(form, "max_pages", &max_pages.to_string());
104        }
105        if let Some(ref page_range) = self.page_range {
106            form = add_form_field(form, "page_range", page_range);
107        }
108        if self.save_checkpoint {
109            form = add_form_field(form, "save_checkpoint", "true");
110        }
111        if self.include_scores {
112            form = add_form_field(form, "include_scores", "true");
113        }
114
115        form
116    }
117}
118
119pub async fn execute(args: ExtractArgs, progress: &Progress) -> Result<()> {
120    let client = DatalabClient::new(Some(args.timeout))?;
121    let cache = Cache::new()?;
122
123    let schema = args.get_schema()?;
124
125    let is_url = args.input.starts_with("http://") || args.input.starts_with("https://");
126    let file_path = if is_url {
127        None
128    } else {
129        Some(PathBuf::from(&args.input))
130    };
131
132    let file_str = file_path.as_ref().map(|p| p.to_string_lossy().to_string());
133    progress.start("extract", file_str.as_deref());
134
135    let file_hash = if let Some(ref path) = file_path {
136        if !path.exists() {
137            return Err(DatalabError::FileNotFound(path.clone()));
138        }
139        Some(Cache::hash_file(path)?)
140    } else {
141        None
142    };
143
144    let cache_params = args.to_cache_params();
145    let cache_key = Cache::generate_key(
146        file_hash.as_deref(),
147        if is_url { Some(&args.input) } else { None },
148        "extract",
149        &cache_params,
150    );
151
152    if !args.skip_cache {
153        if let Some(cached) = cache.get(&cache_key) {
154            progress.cache_hit(&cache_key);
155            output_result(&cached, args.output.as_ref())?;
156            return Ok(());
157        }
158    }
159
160    let form = if let Some(ref path) = file_path {
161        let (form, _) = build_form_with_file(path)?;
162        args.add_to_form(form, &schema)
163    } else {
164        let form = Form::new().text("file_url", args.input.clone());
165        args.add_to_form(form, &schema)
166    };
167
168    let result = client.submit_and_poll("extract", form, progress).await?;
169
170    let file_path_str = file_path.as_ref().map(|p| p.to_string_lossy().to_string());
171    cache.set(
172        &cache_key,
173        &result,
174        "extract",
175        file_hash.as_deref(),
176        file_path_str.as_deref(),
177    )?;
178
179    output_result(&result, args.output.as_ref())?;
180
181    Ok(())
182}
183
184fn output_result(result: &serde_json::Value, output_file: Option<&PathBuf>) -> Result<()> {
185    let json_output = serde_json::to_string_pretty(result)?;
186
187    if let Some(path) = output_file {
188        fs::write(path, &json_output)?;
189    } else {
190        println!("{}", json_output);
191    }
192
193    Ok(())
194}
195
196#[derive(Args, Debug)]
197pub struct ExtractScoreArgs {
198    /// Checkpoint ID from extraction with save_checkpoint=true
199    #[arg(long, value_name = "ID")]
200    pub checkpoint_id: String,
201
202    /// Skip local cache lookup
203    #[arg(long, help_heading = "Cache Options")]
204    pub skip_cache: bool,
205
206    /// Write result to file
207    #[arg(long, short, value_name = "FILE", help_heading = "Output Options")]
208    pub output: Option<PathBuf>,
209
210    /// Request timeout in seconds
211    #[arg(
212        long,
213        default_value = "300",
214        value_name = "SECS",
215        help_heading = "Advanced Options"
216    )]
217    pub timeout: u64,
218}
219
220pub async fn execute_score(args: ExtractScoreArgs, progress: &Progress) -> Result<()> {
221    let client = DatalabClient::new(Some(args.timeout))?;
222    let cache = Cache::new()?;
223
224    progress.start("extract-score", None);
225
226    let cache_params = json!({ "checkpoint_id": args.checkpoint_id });
227    let cache_key = Cache::generate_key(None, None, "extract/score", &cache_params);
228
229    if !args.skip_cache {
230        if let Some(cached) = cache.get(&cache_key) {
231            progress.cache_hit(&cache_key);
232            output_result(&cached, args.output.as_ref())?;
233            return Ok(());
234        }
235    }
236
237    let form = Form::new().text("checkpoint_id", args.checkpoint_id.clone());
238    let result = client
239        .submit_and_poll("extract/score", form, progress)
240        .await?;
241
242    cache.set(&cache_key, &result, "extract/score", None, None)?;
243
244    output_result(&result, args.output.as_ref())?;
245
246    Ok(())
247}