datalab_cli/commands/
segment.rs1use crate::cache::Cache;
2use crate::client::{add_form_field, build_form_with_file, DatalabClient};
3use crate::error::{DatalabError, Result};
4use crate::output::Progress;
5use clap::Args;
6use reqwest::multipart::Form;
7use serde_json::json;
8use std::fs;
9use std::path::PathBuf;
10
11#[derive(Args, Debug)]
12pub struct SegmentArgs {
13 #[arg(value_name = "FILE|URL")]
15 pub input: String,
16
17 #[arg(long, value_name = "SCHEMA")]
19 pub schema: String,
20
21 #[arg(long, value_name = "ID", help_heading = "Processing Options")]
23 pub checkpoint_id: Option<String>,
24
25 #[arg(
27 long,
28 default_value = "fast",
29 value_name = "MODE",
30 help_heading = "Processing Options"
31 )]
32 pub mode: String,
33
34 #[arg(long, value_name = "N", help_heading = "Processing Options")]
36 pub max_pages: Option<u32>,
37
38 #[arg(long, help_heading = "Processing Options")]
40 pub save_checkpoint: bool,
41
42 #[arg(long, help_heading = "Cache Options")]
44 pub skip_cache: bool,
45
46 #[arg(long, short, value_name = "FILE", help_heading = "Output Options")]
48 pub output: Option<PathBuf>,
49
50 #[arg(
52 long,
53 default_value = "300",
54 value_name = "SECS",
55 help_heading = "Advanced Options"
56 )]
57 pub timeout: u64,
58}
59
60impl SegmentArgs {
61 fn to_cache_params(&self) -> serde_json::Value {
62 json!({
63 "schema": self.schema,
64 "checkpoint_id": self.checkpoint_id,
65 "mode": self.mode,
66 "max_pages": self.max_pages,
67 "save_checkpoint": self.save_checkpoint,
68 })
69 }
70
71 fn get_schema(&self) -> Result<String> {
72 let schema_path = PathBuf::from(&self.schema);
73 if schema_path.exists() {
74 Ok(fs::read_to_string(&schema_path)?)
75 } else {
76 serde_json::from_str::<serde_json::Value>(&self.schema).map_err(|_| {
77 DatalabError::InvalidInput(
78 "Schema must be valid JSON or a path to a JSON file".to_string(),
79 )
80 })?;
81 Ok(self.schema.clone())
82 }
83 }
84
85 fn add_to_form(&self, mut form: Form, schema: &str) -> Form {
86 form = add_form_field(form, "segmentation_schema", schema);
87 form = add_form_field(form, "mode", &self.mode);
88
89 if let Some(ref checkpoint_id) = self.checkpoint_id {
90 form = add_form_field(form, "checkpoint_id", checkpoint_id);
91 }
92 if let Some(max_pages) = self.max_pages {
93 form = add_form_field(form, "max_pages", &max_pages.to_string());
94 }
95 if self.save_checkpoint {
96 form = add_form_field(form, "save_checkpoint", "true");
97 }
98
99 form
100 }
101}
102
103pub async fn execute(args: SegmentArgs, progress: &Progress) -> Result<()> {
104 let client = DatalabClient::new(Some(args.timeout))?;
105 let cache = Cache::new()?;
106
107 let schema = args.get_schema()?;
108
109 let is_url = args.input.starts_with("http://") || args.input.starts_with("https://");
110 let file_path = if is_url {
111 None
112 } else {
113 Some(PathBuf::from(&args.input))
114 };
115
116 let file_str = file_path.as_ref().map(|p| p.to_string_lossy().to_string());
117 progress.start("segment", file_str.as_deref());
118
119 let file_hash = if let Some(ref path) = file_path {
120 if !path.exists() {
121 return Err(DatalabError::FileNotFound(path.clone()));
122 }
123 Some(Cache::hash_file(path)?)
124 } else {
125 None
126 };
127
128 let cache_params = args.to_cache_params();
129 let cache_key = Cache::generate_key(
130 file_hash.as_deref(),
131 if is_url { Some(&args.input) } else { None },
132 "segment",
133 &cache_params,
134 );
135
136 if !args.skip_cache {
137 if let Some(cached) = cache.get(&cache_key) {
138 progress.cache_hit(&cache_key);
139 output_result(&cached, args.output.as_ref())?;
140 return Ok(());
141 }
142 }
143
144 let form = if let Some(ref path) = file_path {
145 let (form, _) = build_form_with_file(path)?;
146 args.add_to_form(form, &schema)
147 } else {
148 let form = Form::new().text("file_url", args.input.clone());
149 args.add_to_form(form, &schema)
150 };
151
152 let result = client.submit_and_poll("segment", form, progress).await?;
153
154 let file_path_str = file_path.as_ref().map(|p| p.to_string_lossy().to_string());
155 cache.set(
156 &cache_key,
157 &result,
158 "segment",
159 file_hash.as_deref(),
160 file_path_str.as_deref(),
161 )?;
162
163 output_result(&result, args.output.as_ref())?;
164
165 Ok(())
166}
167
168fn output_result(result: &serde_json::Value, output_file: Option<&PathBuf>) -> Result<()> {
169 let json_output = serde_json::to_string_pretty(result)?;
170
171 if let Some(path) = output_file {
172 fs::write(path, &json_output)?;
173 } else {
174 println!("{}", json_output);
175 }
176
177 Ok(())
178}