1use crate::cache::Cache;
2use crate::client::{add_form_field, build_form_with_file, DatalabClient};
3use crate::error::{DatalabError, Result};
4use crate::output::Progress;
5use clap::Args;
6use reqwest::multipart::Form;
7use serde_json::json;
8use std::fs;
9use std::path::PathBuf;
10
// Command-line arguments for the `extract` command.
//
// NOTE: plain `//` comments are used on purpose in this struct: clap's derive
// turns `///` doc comments into `--help` text, so adding them would change the
// CLI's output.
#[derive(Args, Debug)]
pub struct ExtractArgs {
    // Positional input: a local file path, or a URL. `execute` treats anything
    // starting with `http://` or `https://` as a URL.
    #[arg(value_name = "FILE|URL")]
    pub input: String,

    // Extraction schema: either inline JSON text or a path to a JSON file
    // (resolved by `get_schema`). Sent as the `page_schema` form field.
    #[arg(long, value_name = "SCHEMA", help_heading = "Schema Options")]
    pub schema: String,

    // Optional checkpoint ID, forwarded as the `checkpoint_id` form field.
    #[arg(long, value_name = "ID", help_heading = "Processing Options")]
    pub checkpoint_id: Option<String>,

    // Processing mode, forwarded verbatim as the `mode` form field (default
    // "fast"). Not validated locally; the accepted values are up to the API.
    #[arg(
        long,
        default_value = "fast",
        value_name = "MODE",
        help_heading = "Processing Options"
    )]
    pub mode: String,

    // Optional page cap, forwarded as the `max_pages` form field.
    #[arg(long, value_name = "N", help_heading = "Processing Options")]
    pub max_pages: Option<u32>,

    // Optional page range, forwarded verbatim as the `page_range` form field.
    // NOTE(review): the accepted range syntax is defined by the API, not here.
    #[arg(long, value_name = "RANGE", help_heading = "Processing Options")]
    pub page_range: Option<String>,

    // When set, sends `save_checkpoint=true`.
    #[arg(long, help_heading = "Processing Options")]
    pub save_checkpoint: bool,

    // When set, sends `include_scores=true`.
    #[arg(long, help_heading = "Output Options")]
    pub include_scores: bool,

    // Skip the cache lookup. A freshly fetched result is still written to the
    // cache. Not part of the cache key.
    #[arg(long, help_heading = "Cache Options")]
    pub skip_cache: bool,

    // Write the pretty-printed JSON result to this file instead of stdout.
    // Not part of the cache key.
    #[arg(long, short, value_name = "FILE", help_heading = "Output Options")]
    pub output: Option<PathBuf>,

    // Client timeout in seconds (default 300), passed to `DatalabClient::new`.
    // Not part of the cache key.
    #[arg(
        long,
        default_value = "300",
        value_name = "SECS",
        help_heading = "Advanced Options"
    )]
    pub timeout: u64,
}
67
68impl ExtractArgs {
69 fn to_cache_params(&self) -> serde_json::Value {
70 json!({
71 "schema": self.schema,
72 "checkpoint_id": self.checkpoint_id,
73 "mode": self.mode,
74 "max_pages": self.max_pages,
75 "page_range": self.page_range,
76 "save_checkpoint": self.save_checkpoint,
77 "include_scores": self.include_scores,
78 })
79 }
80
81 fn get_schema(&self) -> Result<String> {
82 let schema_path = PathBuf::from(&self.schema);
83 if schema_path.exists() {
84 Ok(fs::read_to_string(&schema_path)?)
85 } else {
86 serde_json::from_str::<serde_json::Value>(&self.schema).map_err(|_| {
87 DatalabError::InvalidInput(
88 "Schema must be valid JSON or a path to a JSON file".to_string(),
89 )
90 })?;
91 Ok(self.schema.clone())
92 }
93 }
94
95 fn add_to_form(&self, mut form: Form, schema: &str) -> Form {
96 form = add_form_field(form, "page_schema", schema);
97 form = add_form_field(form, "mode", &self.mode);
98
99 if let Some(ref checkpoint_id) = self.checkpoint_id {
100 form = add_form_field(form, "checkpoint_id", checkpoint_id);
101 }
102 if let Some(max_pages) = self.max_pages {
103 form = add_form_field(form, "max_pages", &max_pages.to_string());
104 }
105 if let Some(ref page_range) = self.page_range {
106 form = add_form_field(form, "page_range", page_range);
107 }
108 if self.save_checkpoint {
109 form = add_form_field(form, "save_checkpoint", "true");
110 }
111 if self.include_scores {
112 form = add_form_field(form, "include_scores", "true");
113 }
114
115 form
116 }
117}
118
119pub async fn execute(args: ExtractArgs, progress: &Progress) -> Result<()> {
120 let client = DatalabClient::new(Some(args.timeout))?;
121 let cache = Cache::new()?;
122
123 let schema = args.get_schema()?;
124
125 let is_url = args.input.starts_with("http://") || args.input.starts_with("https://");
126 let file_path = if is_url {
127 None
128 } else {
129 Some(PathBuf::from(&args.input))
130 };
131
132 let file_str = file_path.as_ref().map(|p| p.to_string_lossy().to_string());
133 progress.start("extract", file_str.as_deref());
134
135 let file_hash = if let Some(ref path) = file_path {
136 if !path.exists() {
137 return Err(DatalabError::FileNotFound(path.clone()));
138 }
139 Some(Cache::hash_file(path)?)
140 } else {
141 None
142 };
143
144 let cache_params = args.to_cache_params();
145 let cache_key = Cache::generate_key(
146 file_hash.as_deref(),
147 if is_url { Some(&args.input) } else { None },
148 "extract",
149 &cache_params,
150 );
151
152 if !args.skip_cache {
153 if let Some(cached) = cache.get(&cache_key) {
154 progress.cache_hit(&cache_key);
155 output_result(&cached, args.output.as_ref())?;
156 return Ok(());
157 }
158 }
159
160 let form = if let Some(ref path) = file_path {
161 let (form, _) = build_form_with_file(path)?;
162 args.add_to_form(form, &schema)
163 } else {
164 let form = Form::new().text("file_url", args.input.clone());
165 args.add_to_form(form, &schema)
166 };
167
168 let result = client.submit_and_poll("extract", form, progress).await?;
169
170 let file_path_str = file_path.as_ref().map(|p| p.to_string_lossy().to_string());
171 cache.set(
172 &cache_key,
173 &result,
174 "extract",
175 file_hash.as_deref(),
176 file_path_str.as_deref(),
177 )?;
178
179 output_result(&result, args.output.as_ref())?;
180
181 Ok(())
182}
183
184fn output_result(result: &serde_json::Value, output_file: Option<&PathBuf>) -> Result<()> {
185 let json_output = serde_json::to_string_pretty(result)?;
186
187 if let Some(path) = output_file {
188 fs::write(path, &json_output)?;
189 } else {
190 println!("{}", json_output);
191 }
192
193 Ok(())
194}
195
// Command-line arguments for the `extract-score` command.
//
// NOTE: plain `//` comments are used on purpose in this struct: clap's derive
// turns `///` doc comments into `--help` text, so adding them would change the
// CLI's output.
#[derive(Args, Debug)]
pub struct ExtractScoreArgs {
    // Required checkpoint ID, sent as the `checkpoint_id` form field. It is also
    // the only input to the cache key.
    #[arg(long, value_name = "ID")]
    pub checkpoint_id: String,

    // Skip the cache lookup. A freshly fetched result is still written to the
    // cache.
    #[arg(long, help_heading = "Cache Options")]
    pub skip_cache: bool,

    // Write the pretty-printed JSON result to this file instead of stdout.
    #[arg(long, short, value_name = "FILE", help_heading = "Output Options")]
    pub output: Option<PathBuf>,

    // Client timeout in seconds (default 300), passed to `DatalabClient::new`.
    #[arg(
        long,
        default_value = "300",
        value_name = "SECS",
        help_heading = "Advanced Options"
    )]
    pub timeout: u64,
}
219
220pub async fn execute_score(args: ExtractScoreArgs, progress: &Progress) -> Result<()> {
221 let client = DatalabClient::new(Some(args.timeout))?;
222 let cache = Cache::new()?;
223
224 progress.start("extract-score", None);
225
226 let cache_params = json!({ "checkpoint_id": args.checkpoint_id });
227 let cache_key = Cache::generate_key(None, None, "extract/score", &cache_params);
228
229 if !args.skip_cache {
230 if let Some(cached) = cache.get(&cache_key) {
231 progress.cache_hit(&cache_key);
232 output_result(&cached, args.output.as_ref())?;
233 return Ok(());
234 }
235 }
236
237 let form = Form::new().text("checkpoint_id", args.checkpoint_id.clone());
238 let result = client
239 .submit_and_poll("extract/score", form, progress)
240 .await?;
241
242 cache.set(&cache_key, &result, "extract/score", None, None)?;
243
244 output_result(&result, args.output.as_ref())?;
245
246 Ok(())
247}