#![deny(unsafe_code)]
mod commands;
mod style;
use anyhow::{Context, Result};
use base64::{Engine as _, engine::general_purpose::STANDARD};
use clap::{CommandFactory, Parser, Subcommand};
#[cfg(feature = "embeddings")]
use commands::embed_command;
#[cfg(feature = "mcp")]
use commands::mcp_command;
use commands::overrides::ExtractionOverrides;
#[cfg(feature = "api")]
use commands::serve_command;
use commands::{
batch_command, chunk_command, clear_command, extract_command,
extract_structured::{ExtractStructuredArgs, extract_structured_command},
load_config, manifest_command, stats_command, warm_command,
};
use kreuzberg::{OutputFormat as ContentOutputFormat, detect_mime_type};
use serde_json::json;
use std::path::{Path, PathBuf};
use tracing_subscriber::EnvFilter;
// Top-level command-line interface of the `kreuzberg` binary.
//
// Plain `//` comments are used on purpose in this definition: clap's derive
// macros turn `///` doc comments into help text, which would change the
// program's output.
#[derive(Parser)]
#[command(name = "kreuzberg")]
#[command(version, about, long_about = None)]
struct Cli {
// Optional log filter, accepted on every subcommand (`global = true`).
// `main` passes it to `EnvFilter::new`; when absent, `main` uses the default
// environment filter and falls back to "info".
#[arg(long, global = true)]
log_level: Option<String>,
// The subcommand selected by the user; `main` dispatches on it.
#[command(subcommand)]
command: Commands,
}
// Subcommands of the `kreuzberg` CLI; `main` matches on this enum to dispatch.
//
// Fields without an `#[arg]` attribute are positional arguments. Plain `//`
// comments are used on purpose: clap's derive macros turn `///` doc comments
// into help text, which would change the program's output.
#[derive(Subcommand)]
enum Commands {
// Extract content from a single file (handled by `extract_command`).
Extract {
// File to extract from (positional).
path: PathBuf,
// Optional configuration file handed to `load_config`.
#[arg(short, long)]
config: Option<PathBuf>,
// Inline JSON that is merged over the loaded configuration.
#[arg(long)]
config_json: Option<String>,
// Base64-encoded variant of `config_json`; only consulted when
// `config_json` is absent.
#[arg(long)]
config_json_base64: Option<String>,
// Optional MIME type forwarded to `extract_command`.
#[arg(short, long)]
mime_type: Option<String>,
// Output format; defaults to "text".
#[arg(short, long, default_value = "text")]
format: WireFormat,
// Extra flags flattened in from `ExtractionOverrides`; `main` validates
// them and applies them to the configuration last.
#[command(flatten)]
overrides: ExtractionOverrides,
},
// Structured extraction; all fields are forwarded to
// `extract_structured_command` via `ExtractStructuredArgs`.
ExtractStructured {
// Input file (positional); must exist as a regular file.
path: PathBuf,
// Schema file; `main` also requires it to exist as a regular file.
#[arg(long)]
schema: PathBuf,
// Model name (required). Presumably an LLM identifier; confirm against
// `extract_structured`.
#[arg(long)]
model: String,
// Optional API key forwarded to the command.
#[arg(long)]
api_key: Option<String>,
// Optional prompt forwarded to the command.
#[arg(long)]
prompt: Option<String>,
// Schema name; defaults to "extraction".
#[arg(long, default_value = "extraction")]
schema_name: Option<String>,
// Boolean flag forwarded as `strict`.
#[arg(long)]
strict: bool,
// Optional configuration file path forwarded as `config_path`.
#[arg(short, long)]
config: Option<PathBuf>,
// Output format; defaults to "json".
#[arg(short, long, default_value = "json")]
format: WireFormat,
},
// Extract content from several files at once (handled by `batch_command`).
Batch {
// Input files (positional); `main` rejects an empty list and any entry
// that is not an existing regular file.
paths: Vec<PathBuf>,
// Optional configuration file handed to `load_config`.
#[arg(short, long)]
config: Option<PathBuf>,
// Inline JSON that is merged over the loaded configuration.
#[arg(long)]
config_json: Option<String>,
// Base64-encoded variant of `config_json`; only consulted when
// `config_json` is absent.
#[arg(long)]
config_json_base64: Option<String>,
// Output format; defaults to "json".
#[arg(short, long, default_value = "json")]
format: WireFormat,
// Extra flags flattened in from `ExtractionOverrides`.
#[command(flatten)]
overrides: ExtractionOverrides,
// Optional JSON file that `main` parses into a map from string keys to
// JSON values and passes to `batch_command`. Presumably keyed by file
// path; verify against `batch_command`.
#[arg(long)]
file_configs: Option<PathBuf>,
},
// Detect and print the MIME type of a file.
Detect {
// File to inspect (positional).
path: PathBuf,
// Output format; defaults to "text".
#[arg(short, long, default_value = "text")]
format: WireFormat,
},
// Print the formats returned by `kreuzberg::list_supported_formats`.
Formats {
// Output format; defaults to "text".
#[arg(short, long, default_value = "text")]
format: WireFormat,
},
// Print the package name and version baked in at compile time.
Version {
// Output format; defaults to "text".
#[arg(short, long, default_value = "text")]
format: WireFormat,
},
// Cache management; see `CacheCommands` for the nested subcommands.
Cache {
#[command(subcommand)]
command: CacheCommands,
},
// Start the server via `serve_command` (only with the `api` feature).
#[cfg(feature = "api")]
Serve {
// Optional host, short flag `-H`.
#[arg(short = 'H', long)]
host: Option<String>,
// Optional port.
#[arg(short, long)]
port: Option<u16>,
// Optional configuration file; passed both to `load_config` and to
// `serve_command`.
#[arg(short, long)]
config: Option<PathBuf>,
},
// Run `mcp_command` (only with the `mcp` feature).
#[cfg(feature = "mcp")]
Mcp {
// Optional configuration file handed to `load_config`.
#[arg(short, long)]
config: Option<PathBuf>,
// Transport name; defaults to "stdio".
#[arg(long, default_value = "stdio")]
transport: String,
// Host; defaults to "127.0.0.1".
#[arg(long, default_value = "127.0.0.1")]
host: String,
// Port; defaults to 8001.
#[arg(long, default_value = "8001")]
port: u16,
},
// API-related helpers (only with the `api` feature); see `ApiCommands`.
#[cfg(feature = "api")]
Api {
#[command(subcommand)]
command: ApiCommands,
},
// Compute embeddings via `embed_command` (only with the `embeddings` feature).
#[cfg(feature = "embeddings")]
Embed {
// Input texts; the flag may be repeated. When none are given, `main`
// reads a single text from stdin instead.
#[arg(long)]
text: Vec<String>,
// Preset name; defaults to "balanced".
#[arg(long, default_value = "balanced")]
preset: String,
// Provider name; defaults to "local".
#[arg(long, default_value = "local")]
provider: String,
// Optional model name forwarded to `embed_command`.
#[arg(long)]
model: Option<String>,
// Optional API key forwarded to `embed_command`.
#[arg(long)]
api_key: Option<String>,
// Output format; defaults to "json".
#[arg(short, long, default_value = "json")]
format: WireFormat,
},
// Split text into chunks via `chunk_command`.
Chunk {
// Text to chunk; when absent, `main` reads the input from stdin.
#[arg(long)]
text: Option<String>,
// Optional configuration file; its `chunking` section is the base that
// the flags below override.
#[arg(short, long)]
config: Option<PathBuf>,
// Overrides the configured maximum chunk size (`max_characters`).
#[arg(long)]
chunk_size: Option<usize>,
// Overrides the configured overlap between chunks.
#[arg(long)]
chunk_overlap: Option<usize>,
// Chunker name, interpreted in `main`; defaults to "text".
#[arg(long, default_value = "text")]
chunker_type: String,
// When set, `main` switches chunk sizing to a tokenizer with this model
// name.
#[arg(long)]
chunking_tokenizer: Option<String>,
// When set, overrides the configured `topic_threshold`.
#[arg(long)]
topic_threshold: Option<f32>,
// Output format; defaults to "json".
#[arg(short, long, default_value = "json")]
format: WireFormat,
},
// Write a shell completion script for the given shell to stdout.
Completions {
// Target shell (positional), restricted to clap_complete's `Shell` values.
#[arg(value_enum)]
shell: clap_complete::Shell,
},
}
// Nested subcommands of `Commands::Api`; compiled only with the `api` feature.
// (`//` rather than `///` so that clap does not pick the text up as help.)
#[cfg(feature = "api")]
#[derive(Subcommand)]
enum ApiCommands {
// Print the output of `kreuzberg::api::openapi::openapi_json()` to stdout.
Schema,
}
// Nested subcommands of `Commands::Cache`.
//
// Plain `//` comments are used on purpose: clap's derive macros turn `///` doc
// comments into help text, which would change the program's output.
#[derive(Subcommand)]
enum CacheCommands {
// Report cache statistics via `stats_command`.
Stats {
// Optional cache directory forwarded to the command.
#[arg(short, long)]
cache_dir: Option<PathBuf>,
// Output format; defaults to "text".
#[arg(short, long, default_value = "text")]
format: WireFormat,
},
// Clear the cache via `clear_command`.
Clear {
// Optional cache directory forwarded to the command.
#[arg(short, long)]
cache_dir: Option<PathBuf>,
// Output format; defaults to "text".
#[arg(short, long, default_value = "text")]
format: WireFormat,
},
// Run `manifest_command`.
Manifest {
// Output format; defaults to "json".
#[arg(short, long, default_value = "json")]
format: WireFormat,
},
// Run `warm_command`; every field below is forwarded to it unchanged.
// Presumably pre-downloads models and grammars; confirm in `warm_command`.
Warm {
// Optional cache directory.
#[arg(short, long)]
cache_dir: Option<PathBuf>,
// Output format; defaults to "text".
#[arg(short, long, default_value = "text")]
format: WireFormat,
// Boolean flag `--all-embeddings`.
#[arg(long)]
all_embeddings: bool,
// Optional embedding preset name (shown as PRESET in usage).
#[arg(long, value_name = "PRESET")]
embedding_model: Option<String>,
#[arg(
long,
help = "Download all table structure models including SLANeXT variants (~730MB)"
)]
all_table_models: bool,
// Boolean flag `--all-grammars`.
#[arg(long)]
all_grammars: bool,
// Optional comma-separated list of grammar groups.
#[arg(long, value_name = "GROUPS", value_delimiter = ',')]
grammar_groups: Option<Vec<String>>,
// Optional comma-separated list of grammar languages.
#[arg(long, value_name = "LANGUAGES", value_delimiter = ',')]
grammars: Option<Vec<String>>,
},
}
/// Output encoding selected by the `--format` flag of the subcommands.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum WireFormat {
    /// Plain text output.
    Text,
    /// JSON output.
    Json,
    /// TOON output.
    Toon,
}

impl std::str::FromStr for WireFormat {
    type Err = String;

    /// Parses a format name, ignoring letter case.
    ///
    /// # Errors
    ///
    /// Returns a message naming the rejected input and the accepted values
    /// when `s` is not one of `text`, `json` or `toon`.
    fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
        // Accepted lowercase spellings and the variant each one selects.
        const NAMES: [(&str, WireFormat); 3] = [
            ("text", WireFormat::Text),
            ("json", WireFormat::Json),
            ("toon", WireFormat::Toon),
        ];

        let lowered = s.to_lowercase();
        NAMES
            .iter()
            .find(|(name, _)| *name == lowered.as_str())
            .map(|&(_, format)| format)
            .ok_or_else(|| format!("Invalid format: {}. Use 'text', 'json', or 'toon'", s))
    }
}
// Command-line counterpart of `kreuzberg::OutputFormat` (imported in this file
// as `ContentOutputFormat`); converted to it through the `From` impl below.
//
// Plain `//` comments are used on purpose: this enum derives `clap::ValueEnum`,
// and `///` doc comments on its variants would become help text.
#[derive(Clone, Copy, Debug, PartialEq, Eq, clap::ValueEnum)]
enum ContentOutputFormatArg {
Plain,
Markdown,
Djot,
Html,
Json,
}
impl From<ContentOutputFormatArg> for ContentOutputFormat {
fn from(arg: ContentOutputFormatArg) -> Self {
match arg {
ContentOutputFormatArg::Plain => ContentOutputFormat::Plain,
ContentOutputFormatArg::Markdown => ContentOutputFormat::Markdown,
ContentOutputFormatArg::Djot => ContentOutputFormat::Djot,
ContentOutputFormatArg::Html => ContentOutputFormat::Html,
ContentOutputFormatArg::Json => ContentOutputFormat::Json,
}
}
}
/// Checks that `path` refers to an existing regular file.
///
/// # Errors
///
/// Fails with a "File not found" message when `path.exists()` is false, and
/// with a "Path is not a file" message when the path exists but
/// `path.is_file()` is false (for example a directory).
fn validate_file_exists(path: &Path) -> Result<()> {
    // Happy path first: a regular file necessarily exists.
    if path.is_file() {
        return Ok(());
    }

    if path.exists() {
        anyhow::bail!(
            "Path is not a file: '{}'. Please provide a path to a regular file.",
            path.display()
        );
    }

    anyhow::bail!(
        "File not found: '{}'. Please check that the file exists and is accessible.",
        path.display()
    )
}
/// Validates the `--chunk-size` / `--chunk-overlap` command-line values.
///
/// Nothing is checked when `chunk_size` is `None`; in particular a lone
/// `chunk_overlap` is accepted as is, because there is no size to compare it
/// against here.
///
/// # Errors
///
/// Fails when the size is 0, when it exceeds 1,000,000, or when the overlap
/// is greater than or equal to the size.
fn validate_chunk_params(chunk_size: Option<usize>, chunk_overlap: Option<usize>) -> Result<()> {
    // Largest accepted chunk size, in characters (inclusive).
    const MAX_CHUNK_SIZE: usize = 1_000_000;

    let Some(size) = chunk_size else {
        return Ok(());
    };

    if size == 0 {
        anyhow::bail!("Invalid chunk size: {}. Chunk size must be greater than 0.", size);
    }
    if size > MAX_CHUNK_SIZE {
        // The limit is inclusive, so the message says "at most" rather than
        // "less than": exactly 1,000,000 is a valid size.
        anyhow::bail!(
            "Invalid chunk size: {}. Chunk size must be at most 1,000,000 characters to avoid excessive memory usage.",
            size
        );
    }

    match chunk_overlap {
        Some(overlap) if overlap >= size => anyhow::bail!(
            "Invalid chunk overlap: {}. Overlap ({}) must be less than chunk size ({}).",
            overlap,
            overlap,
            size
        ),
        _ => Ok(()),
    }
}
/// Validates the positional file list of the `batch` subcommand.
///
/// # Errors
///
/// Fails when `paths` is empty, or on the first entry rejected by
/// `validate_file_exists`; that error is wrapped with the entry's 1-based
/// position in the list.
fn validate_batch_paths(paths: &[PathBuf]) -> Result<()> {
    if paths.is_empty() {
        anyhow::bail!("No files provided for batch extraction. Please provide at least one file path.");
    }

    // Pair every path with its 1-based position and stop at the first failure.
    paths.iter().zip(1usize..).try_for_each(|(path, position)| {
        validate_file_exists(path).with_context(|| format!("Invalid file at position {}", position))
    })
}
/// Merges an inline JSON configuration into `config`, replacing it in place.
///
/// `config_json` takes precedence: when it is present, `config_json_base64`
/// is ignored. When neither is present, `config` is left untouched.
///
/// # Errors
///
/// Fails when the base64 payload cannot be decoded or is not valid UTF-8,
/// when the text is not valid JSON, or when the merge itself fails.
fn apply_json_overrides(
    config: &mut kreuzberg::ExtractionConfig,
    config_json: Option<String>,
    config_json_base64: Option<String>,
) -> Result<()> {
    match (config_json, config_json_base64) {
        (Some(raw), _) => {
            let parsed: serde_json::Value =
                serde_json::from_str(&raw).context("Failed to parse --config-json as JSON")?;
            let merged =
                merge_json_into_config(config, parsed).context("Failed to merge --config-json with file config")?;
            *config = merged;
        }
        (None, Some(encoded)) => {
            let decoded = STANDARD
                .decode(&encoded)
                .context("Failed to decode base64 in --config-json-base64")?;
            let text = String::from_utf8(decoded).context("Base64-decoded content is not valid UTF-8")?;
            let parsed: serde_json::Value =
                serde_json::from_str(&text).context("Failed to parse decoded --config-json-base64 as JSON")?;
            let merged = merge_json_into_config(config, parsed)
                .context("Failed to merge --config-json-base64 with file config")?;
            *config = merged;
        }
        (None, None) => {}
    }
    Ok(())
}
/// Delegates to `kreuzberg::core::config::merge::merge_config_json` and
/// returns the merged configuration.
///
/// # Errors
///
/// A merge error is converted into an `anyhow` error that carries only the
/// original error's display text.
fn merge_json_into_config(
    base_config: &kreuzberg::ExtractionConfig,
    json_value: serde_json::Value,
) -> Result<kreuzberg::ExtractionConfig> {
    match kreuzberg::core::config::merge::merge_config_json(base_config, json_value) {
        Ok(merged) => Ok(merged),
        Err(err) => Err(anyhow::anyhow!("{}", err)),
    }
}
fn main() -> Result<()> {
let cli = Cli::parse();
let env_filter = if let Some(ref level) = cli.log_level {
EnvFilter::new(level)
} else {
EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::new("info"))
};
let _ = tracing_subscriber::fmt()
.with_env_filter(env_filter)
.with_writer(std::io::stderr)
.try_init();
match cli.command {
Commands::Extract {
path,
config: config_path,
config_json,
config_json_base64,
mime_type,
format,
overrides,
} => {
validate_file_exists(&path)?;
overrides.validate()?;
let mut config = load_config(config_path)?;
apply_json_overrides(&mut config, config_json, config_json_base64)?;
overrides.apply(&mut config);
extract_command(path, config, mime_type, format)?;
}
Commands::ExtractStructured {
path,
schema,
model,
api_key,
prompt,
schema_name,
strict,
config,
format,
} => {
validate_file_exists(&path)?;
validate_file_exists(&schema)?;
extract_structured_command(ExtractStructuredArgs {
path,
schema_path: schema,
model,
api_key,
prompt,
schema_name,
strict,
config_path: config,
format,
})?;
}
Commands::Batch {
paths,
config: config_path,
config_json,
config_json_base64,
format,
overrides,
file_configs,
} => {
validate_batch_paths(&paths)?;
overrides.validate()?;
let mut config = load_config(config_path)?;
apply_json_overrides(&mut config, config_json, config_json_base64)?;
overrides.apply(&mut config);
let file_configs_map = if let Some(file_configs_path) = file_configs {
let file_configs_json = std::fs::read_to_string(&file_configs_path)
.with_context(|| format!("Failed to read file configs from '{}'", file_configs_path.display()))?;
let map: std::collections::HashMap<String, serde_json::Value> =
serde_json::from_str(&file_configs_json).with_context(|| {
format!(
"Failed to parse file configs JSON from '{}'",
file_configs_path.display()
)
})?;
Some(map)
} else {
None
};
batch_command(paths, file_configs_map, config, format)?;
}
Commands::Detect { path, format } => {
validate_file_exists(&path)?;
let path_str = path.to_string_lossy().to_string();
let mime_type = detect_mime_type(&path_str, true).with_context(|| {
format!(
"Failed to detect MIME type for file '{}'. Ensure the file is readable.",
path.display()
)
})?;
match format {
WireFormat::Text => {
println!("{}", style::success(&mime_type));
}
WireFormat::Json => {
let output = json!({
"path": path_str,
"mime_type": mime_type,
});
println!(
"{}",
serde_json::to_string_pretty(&output)
.context("Failed to serialize MIME type detection result to JSON")?
);
}
WireFormat::Toon => {
let output = json!({
"path": path_str,
"mime_type": mime_type,
});
println!(
"{}",
serde_toon::to_string(&output)
.context("Failed to serialize MIME type detection result to TOON")?
);
}
}
}
Commands::Formats { format } => {
let formats = kreuzberg::list_supported_formats();
match format {
WireFormat::Text => {
println!("{:<15} {}", style::label("EXTENSION"), style::label("MIME TYPE"));
println!("{}", style::dim(&format!("{:<15} ---------", "---------")));
for f in &formats {
println!("{:<15} {}", style::success(&format!(".{}", f.extension)), f.mime_type);
}
}
WireFormat::Json => {
println!(
"{}",
serde_json::to_string_pretty(&formats).context("Failed to serialize formats to JSON")?
);
}
WireFormat::Toon => {
println!(
"{}",
serde_toon::to_string(&formats).context("Failed to serialize formats to TOON")?
);
}
}
}
Commands::Version { format } => {
let version = env!("CARGO_PKG_VERSION");
let name = env!("CARGO_PKG_NAME");
match format {
WireFormat::Text => {
println!("{} {}", style::label(name), style::success(version));
}
WireFormat::Json => {
let output = json!({
"name": name,
"version": version,
});
println!(
"{}",
serde_json::to_string_pretty(&output)
.context("Failed to serialize version information to JSON")?
);
}
WireFormat::Toon => {
let output = json!({
"name": name,
"version": version,
});
println!(
"{}",
serde_toon::to_string(&output).context("Failed to serialize version information to TOON")?
);
}
}
}
#[cfg(feature = "api")]
Commands::Serve {
host: cli_host,
port: cli_port,
config: config_path,
} => {
let mut extraction_config = load_config(config_path.clone())?;
extraction_config.apply_env_overrides()?;
serve_command(cli_host, cli_port, extraction_config, config_path)?;
}
#[cfg(feature = "mcp")]
Commands::Mcp {
config: config_path,
transport,
#[cfg(feature = "mcp-http")]
host,
#[cfg(feature = "mcp-http")]
port,
#[cfg(not(feature = "mcp-http"))]
host,
#[cfg(not(feature = "mcp-http"))]
port,
} => {
let mut config = load_config(config_path)?;
config.apply_env_overrides()?;
mcp_command(config, transport, host, port)?;
}
Commands::Cache { command } => match command {
CacheCommands::Stats { cache_dir, format } => {
stats_command(cache_dir, format)?;
}
CacheCommands::Clear { cache_dir, format } => {
clear_command(cache_dir, format)?;
}
CacheCommands::Manifest { format } => {
manifest_command(format)?;
}
CacheCommands::Warm {
cache_dir,
format,
all_embeddings,
embedding_model,
all_table_models,
all_grammars,
grammar_groups,
grammars,
} => {
warm_command(
cache_dir,
format,
all_embeddings,
embedding_model,
all_table_models,
all_grammars,
grammar_groups,
grammars,
)?;
}
},
#[cfg(feature = "api")]
Commands::Api { command } => match command {
ApiCommands::Schema => {
println!("{}", kreuzberg::api::openapi::openapi_json());
}
},
#[cfg(feature = "embeddings")]
Commands::Embed {
text,
preset,
provider,
model,
api_key,
format,
} => {
let texts = if text.is_empty() {
vec![commands::read_stdin()?]
} else {
text
};
embed_command(texts, &preset, &provider, model, api_key, format)?;
}
Commands::Chunk {
text,
config: config_path,
chunk_size,
chunk_overlap,
chunker_type,
chunking_tokenizer,
topic_threshold,
format,
} => {
let input = match text {
Some(t) => t,
None => commands::read_stdin().context("No --text provided and failed to read from stdin")?,
};
validate_chunk_params(chunk_size, chunk_overlap)?;
let base_config = load_config(config_path)?;
let mut chunking_config = base_config.chunking.unwrap_or_default();
if let Some(size) = chunk_size {
chunking_config.max_characters = size;
if chunk_overlap.is_none() && chunking_config.overlap >= size {
chunking_config.overlap = size / 4;
}
}
if let Some(overlap) = chunk_overlap {
chunking_config.overlap = overlap;
}
match chunker_type.as_str() {
"markdown" => chunking_config.chunker_type = kreuzberg::ChunkerType::Markdown,
"yaml" => chunking_config.chunker_type = kreuzberg::ChunkerType::Yaml,
"semantic" => chunking_config.chunker_type = kreuzberg::ChunkerType::Semantic,
_ => chunking_config.chunker_type = kreuzberg::ChunkerType::Text,
}
if let Some(ref tokenizer) = chunking_tokenizer {
chunking_config.sizing = kreuzberg::ChunkSizing::Tokenizer {
model: tokenizer.clone(),
cache_dir: None,
};
}
if topic_threshold.is_some() {
chunking_config.topic_threshold = topic_threshold;
}
chunk_command(input, chunking_config, format)?;
}
Commands::Completions { shell } => {
let mut cmd = Cli::command();
clap_complete::generate(shell, &mut cmd, "kreuzberg", &mut std::io::stdout());
}
}
Ok(())
}