Skip to main content

datalab_cli/commands/
fill.rs

1use crate::cache::Cache;
2use crate::client::{add_form_field, build_form_with_file, DatalabClient};
3use crate::error::{DatalabError, Result};
4use crate::output::Progress;
5use base64::Engine;
6use clap::Args;
7use reqwest::multipart::Form;
8use serde_json::json;
9use std::fs;
10use std::path::PathBuf;
11
12#[derive(Args, Debug)]
13pub struct FillArgs {
14    /// File path or URL of form to fill
15    #[arg(value_name = "FILE|URL")]
16    pub input: String,
17
18    /// JSON file or string mapping field names to values
19    #[arg(long, value_name = "JSON")]
20    pub fields: String,
21
22    /// Additional context for field matching
23    #[arg(long, value_name = "TEXT", help_heading = "Matching Options")]
24    pub context: Option<String>,
25
26    /// Field matching strictness (0.0-1.0)
27    #[arg(
28        long,
29        default_value = "0.5",
30        value_name = "THRESHOLD",
31        help_heading = "Matching Options"
32    )]
33    pub confidence_threshold: f32,
34
35    /// Maximum pages to process
36    #[arg(long, value_name = "N", help_heading = "Processing Options")]
37    pub max_pages: Option<u32>,
38
39    /// Page range (e.g., "0-5,10")
40    #[arg(long, value_name = "RANGE", help_heading = "Processing Options")]
41    pub page_range: Option<String>,
42
43    /// Skip local cache lookup
44    #[arg(long, help_heading = "Cache Options")]
45    pub skip_cache: bool,
46
47    /// Write filled form to file (binary output)
48    #[arg(long, short, value_name = "FILE", help_heading = "Output Options")]
49    pub output: Option<PathBuf>,
50
51    /// Request timeout in seconds
52    #[arg(
53        long,
54        default_value = "300",
55        value_name = "SECS",
56        help_heading = "Advanced Options"
57    )]
58    pub timeout: u64,
59}
60
61impl FillArgs {
62    fn to_cache_params(&self) -> serde_json::Value {
63        json!({
64            "fields": self.fields,
65            "context": self.context,
66            "confidence_threshold": self.confidence_threshold,
67            "max_pages": self.max_pages,
68            "page_range": self.page_range,
69        })
70    }
71
72    fn get_fields(&self) -> Result<String> {
73        let fields_path = PathBuf::from(&self.fields);
74        if fields_path.exists() {
75            Ok(fs::read_to_string(&fields_path)?)
76        } else {
77            serde_json::from_str::<serde_json::Value>(&self.fields).map_err(|_| {
78                DatalabError::InvalidInput(
79                    "Fields must be valid JSON or a path to a JSON file".to_string(),
80                )
81            })?;
82            Ok(self.fields.clone())
83        }
84    }
85
86    fn add_to_form(&self, mut form: Form, fields: &str) -> Form {
87        form = add_form_field(form, "field_data", fields);
88        form = add_form_field(
89            form,
90            "confidence_threshold",
91            &self.confidence_threshold.to_string(),
92        );
93
94        if let Some(ref context) = self.context {
95            form = add_form_field(form, "context", context);
96        }
97        if let Some(max_pages) = self.max_pages {
98            form = add_form_field(form, "max_pages", &max_pages.to_string());
99        }
100        if let Some(ref page_range) = self.page_range {
101            form = add_form_field(form, "page_range", page_range);
102        }
103
104        form
105    }
106}
107
108pub async fn execute(args: FillArgs, progress: &Progress) -> Result<()> {
109    let client = DatalabClient::new(Some(args.timeout))?;
110    let cache = Cache::new()?;
111
112    let fields = args.get_fields()?;
113
114    let is_url = args.input.starts_with("http://") || args.input.starts_with("https://");
115    let file_path = if is_url {
116        None
117    } else {
118        Some(PathBuf::from(&args.input))
119    };
120
121    let file_str = file_path.as_ref().map(|p| p.to_string_lossy().to_string());
122    progress.start("fill", file_str.as_deref());
123
124    let file_hash = if let Some(ref path) = file_path {
125        if !path.exists() {
126            return Err(DatalabError::FileNotFound(path.clone()));
127        }
128        Some(Cache::hash_file(path)?)
129    } else {
130        None
131    };
132
133    let cache_params = args.to_cache_params();
134    let cache_key = Cache::generate_key(
135        file_hash.as_deref(),
136        if is_url { Some(&args.input) } else { None },
137        "fill",
138        &cache_params,
139    );
140
141    if !args.skip_cache {
142        if let Some(cached) = cache.get(&cache_key) {
143            progress.cache_hit(&cache_key);
144            output_result(&cached, args.output.as_ref())?;
145            return Ok(());
146        }
147    }
148
149    let form = if let Some(ref path) = file_path {
150        let (form, _) = build_form_with_file(path)?;
151        args.add_to_form(form, &fields)
152    } else {
153        let form = Form::new().text("file_url", args.input.clone());
154        args.add_to_form(form, &fields)
155    };
156
157    let result = client.submit_and_poll("fill", form, progress).await?;
158
159    let file_path_str = file_path.as_ref().map(|p| p.to_string_lossy().to_string());
160    cache.set(
161        &cache_key,
162        &result,
163        "fill",
164        file_hash.as_deref(),
165        file_path_str.as_deref(),
166    )?;
167
168    output_result(&result, args.output.as_ref())?;
169
170    Ok(())
171}
172
173fn output_result(result: &serde_json::Value, output_file: Option<&PathBuf>) -> Result<()> {
174    if let Some(path) = output_file {
175        if let Some(base64_data) = result.get("output_base64").and_then(|v| v.as_str()) {
176            let decoded = base64::engine::general_purpose::STANDARD
177                .decode(base64_data)
178                .map_err(|e| DatalabError::InvalidInput(format!("Invalid base64: {}", e)))?;
179            fs::write(path, &decoded)?;
180
181            let mut meta = result.clone();
182            meta.as_object_mut().map(|o| o.remove("output_base64"));
183            println!("{}", serde_json::to_string_pretty(&meta)?);
184        } else {
185            println!("{}", serde_json::to_string_pretty(result)?);
186        }
187    } else {
188        println!("{}", serde_json::to_string_pretty(result)?);
189    }
190
191    Ok(())
192}