datalab-cli 0.1.0

A powerful CLI for converting, extracting, and processing documents using the Datalab API
Documentation
use crate::cache::Cache;
use crate::client::{add_form_field, build_form_with_file, DatalabClient};
use crate::error::{DatalabError, Result};
use crate::output::Progress;
use clap::Args;
use reqwest::multipart::Form;
use serde_json::json;
use std::fs;
use std::path::PathBuf;

// Command-line arguments for the `extract` operation; parsed through clap's
// `Args` derive and consumed by `execute`.
//
// The `///` comments on the fields are intentionally left as they are: clap's
// derive turns field doc comments into the generated `--help` text, so editing
// them changes user-visible output.
#[derive(Args, Debug)]
pub struct ExtractArgs {
    // Treated as a URL when it starts with `http://` or `https://`; anything
    // else is used as a local file path (see `execute`).
    /// File path or URL to extract from
    #[arg(value_name = "FILE|URL")]
    pub input: String,

    // Resolved by `get_schema`: an existing path is read from disk, otherwise
    // the value itself is used as the schema text.
    /// JSON schema file or string defining extraction fields
    #[arg(long, value_name = "SCHEMA", help_heading = "Schema Options")]
    pub schema: String,

    // Sent as the `checkpoint_id` form field only when provided.
    /// Checkpoint ID to reuse parsed document
    #[arg(long, value_name = "ID", help_heading = "Processing Options")]
    pub checkpoint_id: Option<String>,

    // Forwarded verbatim as the `mode` form field; the value is not checked
    // against the listed modes in this file.
    /// Processing mode: fast, balanced, accurate
    #[arg(
        long,
        default_value = "fast",
        value_name = "MODE",
        help_heading = "Processing Options"
    )]
    pub mode: String,

    // Sent as the `max_pages` form field only when provided.
    /// Maximum pages to process
    #[arg(long, value_name = "N", help_heading = "Processing Options")]
    pub max_pages: Option<u32>,

    // Sent as the `page_range` form field only when provided; the string is
    // passed through without local parsing.
    /// Page range (e.g., "0-5,10")
    #[arg(long, value_name = "RANGE", help_heading = "Processing Options")]
    pub page_range: Option<String>,

    // When set, the form carries `save_checkpoint=true`; otherwise the field
    // is omitted.
    /// Save checkpoint for reuse
    #[arg(long, help_heading = "Processing Options")]
    pub save_checkpoint: bool,

    // When set, the form carries `include_scores=true`; otherwise the field
    // is omitted.
    /// Include per-field confidence scores
    #[arg(long, help_heading = "Output Options")]
    pub include_scores: bool,

    // Only the lookup is skipped: `execute` still stores the fresh result in
    // the cache afterwards.
    /// Skip local cache lookup
    #[arg(long, help_heading = "Cache Options")]
    pub skip_cache: bool,

    // When absent, the result is printed to stdout instead.
    /// Write result to file
    #[arg(long, short, value_name = "FILE", help_heading = "Output Options")]
    pub output: Option<PathBuf>,

    // Handed to `DatalabClient::new` as `Some(timeout)`.
    /// Request timeout in seconds
    #[arg(
        long,
        default_value = "300",
        value_name = "SECS",
        help_heading = "Advanced Options"
    )]
    pub timeout: u64,
}

impl ExtractArgs {
    /// Collects the options that affect the extraction result into a JSON
    /// object, used as the parameter component of the cache key.
    ///
    /// `schema` is the raw `--schema` value exactly as given on the command
    /// line (a path or inline JSON), not the contents of a schema file. A
    /// caller that needs the key to follow the file's contents has to
    /// substitute the resolved schema text itself.
    fn to_cache_params(&self) -> serde_json::Value {
        json!({
            "schema": self.schema,
            "checkpoint_id": self.checkpoint_id,
            "mode": self.mode,
            "max_pages": self.max_pages,
            "page_range": self.page_range,
            "save_checkpoint": self.save_checkpoint,
            "include_scores": self.include_scores,
        })
    }

    /// Resolves `--schema` to the schema text that is sent to the API.
    ///
    /// If the argument names an existing path, that file's contents are used;
    /// otherwise the argument itself is taken as inline JSON. Either way the
    /// resulting text must parse as JSON, so a malformed schema file is
    /// rejected locally just like a malformed inline string.
    ///
    /// # Errors
    ///
    /// Propagates the I/O error if the schema file cannot be read, and returns
    /// `DatalabError::InvalidInput` if the resolved text is not valid JSON.
    fn get_schema(&self) -> Result<String> {
        let schema_path = PathBuf::from(&self.schema);
        let schema = if schema_path.exists() {
            fs::read_to_string(&schema_path)?
        } else {
            self.schema.clone()
        };

        // Parse only to validate; the original text is what gets returned so
        // the caller sends exactly what the user wrote.
        serde_json::from_str::<serde_json::Value>(&schema).map_err(|_| {
            DatalabError::InvalidInput(
                "Schema must be valid JSON or a path to a JSON file".to_string(),
            )
        })?;

        Ok(schema)
    }

    /// Appends the extraction options to `form` and returns the extended form.
    ///
    /// `page_schema` and `mode` are always added. The optional values are
    /// added only when present, and the boolean flags only when set (as the
    /// string `"true"`).
    fn add_to_form(&self, mut form: Form, schema: &str) -> Form {
        form = add_form_field(form, "page_schema", schema);
        form = add_form_field(form, "mode", &self.mode);

        if let Some(ref checkpoint_id) = self.checkpoint_id {
            form = add_form_field(form, "checkpoint_id", checkpoint_id);
        }
        if let Some(max_pages) = self.max_pages {
            form = add_form_field(form, "max_pages", &max_pages.to_string());
        }
        if let Some(ref page_range) = self.page_range {
            form = add_form_field(form, "page_range", page_range);
        }
        if self.save_checkpoint {
            form = add_form_field(form, "save_checkpoint", "true");
        }
        if self.include_scores {
            form = add_form_field(form, "include_scores", "true");
        }

        form
    }
}

pub async fn execute(args: ExtractArgs, progress: &Progress) -> Result<()> {
    let client = DatalabClient::new(Some(args.timeout))?;
    let cache = Cache::new()?;

    let schema = args.get_schema()?;

    let is_url = args.input.starts_with("http://") || args.input.starts_with("https://");
    let file_path = if is_url {
        None
    } else {
        Some(PathBuf::from(&args.input))
    };

    let file_str = file_path.as_ref().map(|p| p.to_string_lossy().to_string());
    progress.start("extract", file_str.as_deref());

    let file_hash = if let Some(ref path) = file_path {
        if !path.exists() {
            return Err(DatalabError::FileNotFound(path.clone()));
        }
        Some(Cache::hash_file(path)?)
    } else {
        None
    };

    let cache_params = args.to_cache_params();
    let cache_key = Cache::generate_key(
        file_hash.as_deref(),
        if is_url { Some(&args.input) } else { None },
        "extract",
        &cache_params,
    );

    if !args.skip_cache {
        if let Some(cached) = cache.get(&cache_key) {
            progress.cache_hit(&cache_key);
            output_result(&cached, args.output.as_ref())?;
            return Ok(());
        }
    }

    let form = if let Some(ref path) = file_path {
        let (form, _) = build_form_with_file(path)?;
        args.add_to_form(form, &schema)
    } else {
        let form = Form::new().text("file_url", args.input.clone());
        args.add_to_form(form, &schema)
    };

    let result = client.submit_and_poll("extract", form, progress).await?;

    let file_path_str = file_path.as_ref().map(|p| p.to_string_lossy().to_string());
    cache.set(
        &cache_key,
        &result,
        "extract",
        file_hash.as_deref(),
        file_path_str.as_deref(),
    )?;

    output_result(&result, args.output.as_ref())?;

    Ok(())
}

/// Pretty-prints `result` as JSON and writes it to `output_file`, or to stdout
/// followed by a newline when no file is given.
///
/// Stdout is written through `std::io::Write` instead of `println!`, so a
/// closed pipe (for example `... | head`) surfaces as an error return rather
/// than a panic.
///
/// # Errors
///
/// Returns the serialization error if `result` cannot be rendered, or the I/O
/// error from writing the file or stdout.
fn output_result(result: &serde_json::Value, output_file: Option<&PathBuf>) -> Result<()> {
    let mut json_output = serde_json::to_string_pretty(result)?;

    match output_file {
        // The file receives the JSON text exactly, without a trailing newline.
        Some(path) => fs::write(path, &json_output)?,
        None => {
            json_output.push('\n');
            let mut stdout = std::io::stdout().lock();
            // Fully qualified calls keep this block independent of the file's
            // import list.
            std::io::Write::write_all(&mut stdout, json_output.as_bytes())?;
            std::io::Write::flush(&mut stdout)?;
        }
    }

    Ok(())
}

// Command-line arguments for the `extract-score` operation; parsed through
// clap's `Args` derive and consumed by `execute_score`.
//
// The `///` comments on the fields are intentionally left as they are: clap's
// derive turns field doc comments into the generated `--help` text, so editing
// them changes user-visible output.
#[derive(Args, Debug)]
pub struct ExtractScoreArgs {
    // Required. It is the only request parameter and the only input to the
    // cache key in `execute_score`.
    /// Checkpoint ID from extraction with save_checkpoint=true
    #[arg(long, value_name = "ID")]
    pub checkpoint_id: String,

    // Only the lookup is skipped: `execute_score` still stores the fresh
    // result in the cache afterwards.
    /// Skip local cache lookup
    #[arg(long, help_heading = "Cache Options")]
    pub skip_cache: bool,

    // When absent, the result is printed to stdout instead.
    /// Write result to file
    #[arg(long, short, value_name = "FILE", help_heading = "Output Options")]
    pub output: Option<PathBuf>,

    // Handed to `DatalabClient::new` as `Some(timeout)`.
    /// Request timeout in seconds
    #[arg(
        long,
        default_value = "300",
        value_name = "SECS",
        help_heading = "Advanced Options"
    )]
    pub timeout: u64,
}

/// Runs the `extract/score` operation for a previously saved checkpoint and
/// emits the JSON result.
///
/// The cache key is derived solely from the checkpoint ID. Unless
/// `--skip-cache` is set, a cached result is returned without contacting the
/// API; otherwise the request is submitted, the result is cached, and it is
/// written to `--output` or stdout.
///
/// # Errors
///
/// Propagates failures from client/cache setup, the API call, the cache write,
/// and output.
pub async fn execute_score(args: ExtractScoreArgs, progress: &Progress) -> Result<()> {
    // Endpoint name shared by the cache key, the request, and the cache entry.
    const ENDPOINT: &str = "extract/score";

    let client = DatalabClient::new(Some(args.timeout))?;
    let cache = Cache::new()?;

    progress.start("extract-score", None);

    let key_params = json!({ "checkpoint_id": args.checkpoint_id });
    let cache_key = Cache::generate_key(None, None, ENDPOINT, &key_params);

    // The cache is not consulted at all when the lookup is skipped.
    let cached_hit = if args.skip_cache {
        None
    } else {
        cache.get(&cache_key)
    };
    if let Some(hit) = cached_hit {
        progress.cache_hit(&cache_key);
        return output_result(&hit, args.output.as_ref());
    }

    let request = Form::new().text("checkpoint_id", args.checkpoint_id.clone());
    let response = client.submit_and_poll(ENDPOINT, request, progress).await?;

    cache.set(&cache_key, &response, ENDPOINT, None, None)?;

    output_result(&response, args.output.as_ref())
}