Skip to main content

datalab_cli/commands/segment.rs

1use crate::cache::Cache;
2use crate::client::{add_form_field, build_form_with_file, DatalabClient};
3use crate::error::{DatalabError, Result};
4use crate::output::Progress;
5use clap::Args;
6use reqwest::multipart::Form;
7use serde_json::json;
8use std::fs;
9use std::path::PathBuf;
10
11#[derive(Args, Debug)]
12pub struct SegmentArgs {
13    /// File path or URL to segment
14    #[arg(value_name = "FILE|URL")]
15    pub input: String,
16
17    /// JSON schema file or string defining segmentation structure
18    #[arg(long, value_name = "SCHEMA")]
19    pub schema: String,
20
21    /// Checkpoint ID to reuse parsed document
22    #[arg(long, value_name = "ID", help_heading = "Processing Options")]
23    pub checkpoint_id: Option<String>,
24
25    /// Processing mode: fast, balanced, accurate
26    #[arg(
27        long,
28        default_value = "fast",
29        value_name = "MODE",
30        help_heading = "Processing Options"
31    )]
32    pub mode: String,
33
34    /// Maximum pages to process
35    #[arg(long, value_name = "N", help_heading = "Processing Options")]
36    pub max_pages: Option<u32>,
37
38    /// Save checkpoint for reuse
39    #[arg(long, help_heading = "Processing Options")]
40    pub save_checkpoint: bool,
41
42    /// Skip local cache lookup
43    #[arg(long, help_heading = "Cache Options")]
44    pub skip_cache: bool,
45
46    /// Write result to file
47    #[arg(long, short, value_name = "FILE", help_heading = "Output Options")]
48    pub output: Option<PathBuf>,
49
50    /// Request timeout in seconds
51    #[arg(
52        long,
53        default_value = "300",
54        value_name = "SECS",
55        help_heading = "Advanced Options"
56    )]
57    pub timeout: u64,
58}
59
60impl SegmentArgs {
61    fn to_cache_params(&self) -> serde_json::Value {
62        json!({
63            "schema": self.schema,
64            "checkpoint_id": self.checkpoint_id,
65            "mode": self.mode,
66            "max_pages": self.max_pages,
67            "save_checkpoint": self.save_checkpoint,
68        })
69    }
70
71    fn get_schema(&self) -> Result<String> {
72        let schema_path = PathBuf::from(&self.schema);
73        if schema_path.exists() {
74            Ok(fs::read_to_string(&schema_path)?)
75        } else {
76            serde_json::from_str::<serde_json::Value>(&self.schema).map_err(|_| {
77                DatalabError::InvalidInput(
78                    "Schema must be valid JSON or a path to a JSON file".to_string(),
79                )
80            })?;
81            Ok(self.schema.clone())
82        }
83    }
84
85    fn add_to_form(&self, mut form: Form, schema: &str) -> Form {
86        form = add_form_field(form, "segmentation_schema", schema);
87        form = add_form_field(form, "mode", &self.mode);
88
89        if let Some(ref checkpoint_id) = self.checkpoint_id {
90            form = add_form_field(form, "checkpoint_id", checkpoint_id);
91        }
92        if let Some(max_pages) = self.max_pages {
93            form = add_form_field(form, "max_pages", &max_pages.to_string());
94        }
95        if self.save_checkpoint {
96            form = add_form_field(form, "save_checkpoint", "true");
97        }
98
99        form
100    }
101}
102
103pub async fn execute(args: SegmentArgs, progress: &Progress) -> Result<()> {
104    let client = DatalabClient::new(Some(args.timeout))?;
105    let cache = Cache::new()?;
106
107    let schema = args.get_schema()?;
108
109    let is_url = args.input.starts_with("http://") || args.input.starts_with("https://");
110    let file_path = if is_url {
111        None
112    } else {
113        Some(PathBuf::from(&args.input))
114    };
115
116    let file_str = file_path.as_ref().map(|p| p.to_string_lossy().to_string());
117    progress.start("segment", file_str.as_deref());
118
119    let file_hash = if let Some(ref path) = file_path {
120        if !path.exists() {
121            return Err(DatalabError::FileNotFound(path.clone()));
122        }
123        Some(Cache::hash_file(path)?)
124    } else {
125        None
126    };
127
128    let cache_params = args.to_cache_params();
129    let cache_key = Cache::generate_key(
130        file_hash.as_deref(),
131        if is_url { Some(&args.input) } else { None },
132        "segment",
133        &cache_params,
134    );
135
136    if !args.skip_cache {
137        if let Some(cached) = cache.get(&cache_key) {
138            progress.cache_hit(&cache_key);
139            output_result(&cached, args.output.as_ref())?;
140            return Ok(());
141        }
142    }
143
144    let form = if let Some(ref path) = file_path {
145        let (form, _) = build_form_with_file(path)?;
146        args.add_to_form(form, &schema)
147    } else {
148        let form = Form::new().text("file_url", args.input.clone());
149        args.add_to_form(form, &schema)
150    };
151
152    let result = client.submit_and_poll("segment", form, progress).await?;
153
154    let file_path_str = file_path.as_ref().map(|p| p.to_string_lossy().to_string());
155    cache.set(
156        &cache_key,
157        &result,
158        "segment",
159        file_hash.as_deref(),
160        file_path_str.as_deref(),
161    )?;
162
163    output_result(&result, args.output.as_ref())?;
164
165    Ok(())
166}
167
168fn output_result(result: &serde_json::Value, output_file: Option<&PathBuf>) -> Result<()> {
169    let json_output = serde_json::to_string_pretty(result)?;
170
171    if let Some(path) = output_file {
172        fs::write(path, &json_output)?;
173    } else {
174        println!("{}", json_output);
175    }
176
177    Ok(())
178}