use crate::InfernoError;
use crate::backends::InferenceParams;
use crate::multimodal::{
AudioFeatures, ModelCapabilities, MultiModalConfig, MultiModalProcessor, ProcessingStatus,
VisionFeatures,
};
use clap::{Args, Subcommand};
use serde_json;
use std::collections::HashMap;
use std::path::PathBuf;
use tokio::time::{Duration, sleep};
#[derive(Args)]
pub struct MultiModalArgs {
#[command(subcommand)]
pub command: MultiModalCommands,
}
#[derive(Subcommand)]
pub enum MultiModalCommands {
Process {
#[arg(short, long)]
model: String,
#[arg(short, long)]
input: PathBuf,
#[arg(short, long)]
prompt: Option<String>,
#[arg(long, default_value = "500")]
max_tokens: u32,
#[arg(long, default_value = "0.7")]
temperature: f32,
#[arg(short, long, default_value = "text")]
output_format: String,
#[arg(short = 'o', long)]
output_file: Option<PathBuf>,
#[arg(long)]
_show_progress: bool,
},
ProcessBase64 {
#[arg(short, long)]
model: String,
#[arg(short, long)]
data: String,
#[arg(short = 't', long)]
media_type: String,
#[arg(short, long)]
prompt: Option<String>,
#[arg(long, default_value = "500")]
max_tokens: u32,
#[arg(long, default_value = "0.7")]
temperature: f32,
#[arg(short, long, default_value = "text")]
output_format: String,
},
Batch {
#[arg(short, long)]
model: String,
#[arg(short, long)]
input_dir: PathBuf,
#[arg(short, long, default_value = "*")]
pattern: String,
#[arg(short, long)]
prompt: Option<String>,
#[arg(short, long)]
output_dir: PathBuf,
#[arg(long, default_value = "3")]
max_concurrent: u32,
#[arg(long)]
continue_on_error: bool,
},
Sessions {
#[arg(short, long)]
detailed: bool,
#[arg(short, long)]
refresh: Option<u32>,
},
Status {
session_id: String,
#[arg(short, long)]
follow: bool,
},
Cancel {
session_id: String,
},
Formats {
#[arg(short, long)]
media_type: Option<String>,
#[arg(short, long)]
examples: bool,
},
Capabilities {
#[arg(short, long)]
model: Option<String>,
#[arg(short = 't', long)]
media_type: Option<String>,
#[arg(short, long, default_value = "table")]
format: String,
},
RegisterModel {
#[arg(short, long)]
model: String,
#[arg(short, long)]
config_file: PathBuf,
},
Analyze {
#[arg(short, long)]
input: PathBuf,
#[arg(short, long)]
detailed: bool,
#[arg(short, long, default_value = "table")]
format: String,
},
Convert {
#[arg(short, long)]
input: PathBuf,
#[arg(short, long)]
output: PathBuf,
#[arg(short, long)]
format: String,
#[arg(short, long)]
quality: Option<u32>,
#[arg(short, long)]
params: Vec<String>,
},
}
pub struct ProcessFileConfig {
pub model: String,
pub input: PathBuf,
pub prompt: Option<String>,
pub max_tokens: u32,
pub temperature: f32,
pub output_format: String,
pub output_file: Option<PathBuf>,
pub show_progress: bool,
}
pub struct ProcessBase64Config {
pub model: String,
pub data: String,
pub media_type: String,
pub prompt: Option<String>,
pub max_tokens: u32,
pub temperature: f32,
pub output_format: String,
}
pub struct BatchProcessConfig {
pub model: String,
pub input_dir: PathBuf,
pub pattern: String,
pub prompt: Option<String>,
pub output_dir: PathBuf,
pub max_concurrent: u32,
pub continue_on_error: bool,
}
pub async fn handle_multimodal_command(args: MultiModalArgs) -> Result<(), InfernoError> {
let config = MultiModalConfig::default();
let processor = MultiModalProcessor::new(config);
processor.initialize().await.map_err(|e| {
InfernoError::InvalidArgument(format!("Failed to initialize processor: {}", e))
})?;
match args.command {
MultiModalCommands::Process {
model,
input,
prompt,
max_tokens,
temperature,
output_format,
output_file,
_show_progress,
} => {
let config = ProcessFileConfig {
model,
input,
prompt,
max_tokens,
temperature,
output_format,
output_file,
show_progress: _show_progress,
};
handle_process_command(&processor, config).await
}
MultiModalCommands::ProcessBase64 {
model,
data,
media_type,
prompt,
max_tokens,
temperature,
output_format,
} => {
let config = ProcessBase64Config {
model,
data,
media_type,
prompt,
max_tokens,
temperature,
output_format,
};
handle_process_base64_command(&processor, config).await
}
MultiModalCommands::Batch {
model,
input_dir,
pattern,
prompt,
output_dir,
max_concurrent,
continue_on_error,
} => {
let config = BatchProcessConfig {
model,
input_dir,
pattern,
prompt,
output_dir,
max_concurrent,
continue_on_error,
};
handle_batch_command(&processor, config).await
}
MultiModalCommands::Sessions { detailed, refresh } => {
handle_sessions_command(&processor, detailed, refresh).await
}
MultiModalCommands::Status { session_id, follow } => {
handle_status_command(&processor, session_id, follow).await
}
MultiModalCommands::Cancel { session_id } => {
handle_cancel_command(&processor, session_id).await
}
MultiModalCommands::Formats {
media_type,
examples,
} => handle_formats_command(&processor, media_type, examples).await,
MultiModalCommands::Capabilities {
model,
media_type,
format,
} => handle_capabilities_command(&processor, model, media_type, format).await,
MultiModalCommands::RegisterModel { model, config_file } => {
handle_register_model_command(&processor, model, config_file).await
}
MultiModalCommands::Analyze {
input,
detailed,
format,
} => handle_analyze_command(&processor, input, detailed, format).await,
MultiModalCommands::Convert {
input,
output,
format,
quality,
params,
} => handle_convert_command(input, output, format, quality, params).await,
}
}
async fn handle_process_command(
processor: &MultiModalProcessor,
config: ProcessFileConfig,
) -> Result<(), InfernoError> {
if config.model.is_empty() {
return Err(InfernoError::InvalidArgument(
"Model name cannot be empty".to_string(),
));
}
if !config.input.exists() {
return Err(InfernoError::InvalidArgument(format!(
"Input file does not exist: {:?}",
config.input
)));
}
println!("Processing file: {:?}", config.input);
let params = InferenceParams {
max_tokens: config.max_tokens,
temperature: config.temperature,
top_k: 40,
top_p: 0.9,
stream: false,
stop_sequences: vec![],
seed: None,
};
let result = processor
.process_file(&config.model, &config.input, config.prompt, params)
.await
.map_err(|e| InfernoError::InvalidArgument(format!("Processing failed: {}", e)))?;
let output_content = match config.output_format.as_str() {
"json" => serde_json::to_string_pretty(&result).map_err(|e| {
InfernoError::InvalidArgument(format!("JSON serialization failed: {}", e))
})?,
"text" => format_text_output(&result),
_ => {
return Err(InfernoError::InvalidArgument(format!(
"Unsupported output format: {}",
config.output_format
)));
}
};
if let Some(output_path) = config.output_file {
tokio::fs::write(&output_path, &output_content).await?;
println!("Results saved to: {:?}", output_path);
} else {
println!("{}", output_content);
}
Ok(())
}
async fn handle_process_base64_command(
processor: &MultiModalProcessor,
config: ProcessBase64Config,
) -> Result<(), InfernoError> {
if config.model.is_empty() {
return Err(InfernoError::InvalidArgument(
"Model name cannot be empty".to_string(),
));
}
if config.data.is_empty() {
return Err(InfernoError::InvalidArgument(
"Base64 data cannot be empty".to_string(),
));
}
if config.media_type.is_empty() {
return Err(InfernoError::InvalidArgument(
"Media type cannot be empty".to_string(),
));
}
let supported_types = [
"jpg", "jpeg", "png", "gif", "webp", "bmp", "tiff", "mp3", "wav", "flac", "ogg", "m4a", "aac", "mp4", "avi", "mov", "mkv", "webm", ];
if !supported_types.contains(&config.media_type.to_lowercase().as_str()) {
return Err(InfernoError::InvalidArgument(format!(
"Unsupported media type '{}'. Supported types: {}",
config.media_type,
supported_types.join(", ")
)));
}
println!("Processing base64 data ({} type)", config.media_type);
let params = InferenceParams {
max_tokens: config.max_tokens,
temperature: config.temperature,
top_k: 40,
top_p: 0.9,
stream: false,
stop_sequences: vec![],
seed: None,
};
let result = processor
.process_base64(
&config.model,
&config.data,
&config.media_type,
config.prompt,
params,
)
.await
.map_err(|e| InfernoError::InvalidArgument(format!("Processing failed: {}", e)))?;
let output_content = match config.output_format.as_str() {
"json" => serde_json::to_string_pretty(&result).map_err(|e| {
InfernoError::InvalidArgument(format!("JSON serialization failed: {}", e))
})?,
"text" => format_text_output(&result),
_ => {
return Err(InfernoError::InvalidArgument(format!(
"Unsupported output format: {}",
config.output_format
)));
}
};
println!("{}", output_content);
Ok(())
}
async fn handle_batch_command(
processor: &MultiModalProcessor,
config: BatchProcessConfig,
) -> Result<(), InfernoError> {
if config.model.is_empty() {
return Err(InfernoError::InvalidArgument(
"Model name cannot be empty".to_string(),
));
}
if !config.input_dir.exists() {
return Err(InfernoError::InvalidArgument(format!(
"Input directory does not exist: {:?}",
config.input_dir
)));
}
if !config.input_dir.is_dir() {
return Err(InfernoError::InvalidArgument(format!(
"Input path is not a directory: {:?}",
config.input_dir
)));
}
if config.max_concurrent == 0 {
return Err(InfernoError::InvalidArgument(
"Max concurrent jobs must be at least 1".to_string(),
));
}
println!("Starting batch processing...");
println!("Input directory: {:?}", config.input_dir);
println!("Pattern: {}", config.pattern);
println!("Output directory: {:?}", config.output_dir);
println!("Max concurrent jobs: {}", config.max_concurrent);
tokio::fs::create_dir_all(&config.output_dir).await?;
let files = find_matching_files(&config.input_dir, &config.pattern).await?;
println!("Found {} files to process", files.len());
if files.is_empty() {
println!("No files found matching pattern");
return Ok(());
}
let params = InferenceParams {
max_tokens: 500,
temperature: 0.7,
top_k: 40,
top_p: 0.9,
stream: false,
stop_sequences: vec![],
seed: None,
};
let mut processed = 0;
let mut failed = 0;
let continue_on_error = config.continue_on_error;
for chunk in files.chunks(config.max_concurrent as usize) {
let mut tasks = Vec::new();
for file_path in chunk {
let processor_ref = processor;
let model_ref = &config.model;
let prompt_ref = &config.prompt;
let params_ref = params.clone();
let output_dir_ref = &config.output_dir;
let task = async move {
let result = processor_ref
.process_file(model_ref, file_path, prompt_ref.clone(), params_ref)
.await;
match result {
Ok(res) => {
let file_stem = file_path
.file_stem()
.and_then(|s| s.to_str())
.unwrap_or("unknown");
let output_file = output_dir_ref.join(format!("{}_result.json", file_stem));
if let Err(e) = tokio::fs::write(
&output_file,
serde_json::to_string_pretty(&res).unwrap_or_default(),
)
.await
{
eprintln!("Failed to save result for {:?}: {}", file_path, e);
return false;
}
println!("✅ Processed: {:?} -> {:?}", file_path, output_file);
true
}
Err(e) => {
eprintln!("❌ Failed to process {:?}: {}", file_path, e);
!continue_on_error
}
}
};
tasks.push(task);
}
for task in tasks {
if task.await {
processed += 1;
} else {
failed += 1;
if !continue_on_error {
break;
}
}
}
if failed > 0 && !continue_on_error {
break;
}
}
println!("\nBatch processing completed:");
println!("✅ Processed: {}", processed);
println!("❌ Failed: {}", failed);
Ok(())
}
async fn handle_sessions_command(
processor: &MultiModalProcessor,
detailed: bool,
refresh: Option<u32>,
) -> Result<(), InfernoError> {
if let Some(interval) = refresh {
loop {
print!("\x1B[2J\x1B[1;1H"); println!(
"🔄 Active Processing Sessions (refreshing every {}s)",
interval
);
println!("{}", "=".repeat(60));
display_sessions(processor, detailed).await?;
sleep(Duration::from_secs(interval as u64)).await;
}
} else {
display_sessions(processor, detailed).await?;
}
Ok(())
}
async fn display_sessions(
processor: &MultiModalProcessor,
detailed: bool,
) -> Result<(), InfernoError> {
let sessions = processor
.list_active_sessions()
.await
.map_err(|e| InfernoError::InvalidArgument(format!("Failed to list sessions: {}", e)))?;
if sessions.is_empty() {
println!("No active processing sessions");
return Ok(());
}
if detailed {
for session in sessions {
println!("\n📋 Session Details:");
println!("ID: {}", session.id);
println!("Model: {}", session.model_id);
println!("Status: {:?}", session.status);
println!("Progress: {:.1}%", session.progress);
println!(
"Created: {}",
session.created_at.format("%Y-%m-%d %H:%M:%S")
);
println!(
"Updated: {}",
session.updated_at.format("%Y-%m-%d %H:%M:%S")
);
}
} else {
println!(
"{:<15} {:<20} {:<12} {:<8} {:<20}",
"Session ID", "Model", "Status", "Progress", "Created"
);
println!("{}", "-".repeat(80));
for session in sessions {
let short_id = if session.id.len() > 12 {
format!("{}...", &session.id[..9])
} else {
session.id
};
println!(
"{:<15} {:<20} {:<12} {:<8} {:<20}",
short_id,
session.model_id,
format!("{:?}", session.status),
format!("{:.1}%", session.progress),
session.created_at.format("%H:%M:%S")
);
}
}
Ok(())
}
async fn handle_status_command(
processor: &MultiModalProcessor,
session_id: String,
follow: bool,
) -> Result<(), InfernoError> {
if follow {
loop {
match processor.get_session_status(&session_id).await {
Ok(Some(session)) => {
print!(
"\r🔄 Session {} - Status: {:?} - Progress: {:.1}%",
&session_id[..8.min(session_id.len())],
session.status,
session.progress
);
std::io::Write::flush(&mut std::io::stdout()).unwrap();
if matches!(
session.status,
ProcessingStatus::Completed
| ProcessingStatus::Failed
| ProcessingStatus::Cancelled
) {
println!();
match session.status {
ProcessingStatus::Completed => {
println!("✅ Processing completed successfully!")
}
ProcessingStatus::Failed => println!("❌ Processing failed!"),
ProcessingStatus::Cancelled => println!("⚠️ Processing was cancelled"),
_ => {}
}
break;
}
}
Ok(None) => {
println!("Session not found: {}", session_id);
break;
}
Err(e) => {
println!("Error getting session status: {}", e);
break;
}
}
sleep(Duration::from_secs(2)).await;
}
} else {
match processor.get_session_status(&session_id).await {
Ok(Some(session)) => {
println!("Session ID: {}", session.id);
println!("Model: {}", session.model_id);
println!("Status: {:?}", session.status);
println!("Progress: {:.1}%", session.progress);
println!(
"Created: {}",
session.created_at.format("%Y-%m-%d %H:%M:%S")
);
println!(
"Updated: {}",
session.updated_at.format("%Y-%m-%d %H:%M:%S")
);
}
Ok(None) => {
println!("Session not found: {}", session_id);
}
Err(e) => {
return Err(InfernoError::InvalidArgument(format!(
"Failed to get session status: {}",
e
)));
}
}
}
Ok(())
}
async fn handle_cancel_command(
processor: &MultiModalProcessor,
session_id: String,
) -> Result<(), InfernoError> {
if session_id.is_empty() {
return Err(InfernoError::InvalidArgument(
"Session ID cannot be empty".to_string(),
));
}
processor
.cancel_session(&session_id)
.await
.map_err(|e| InfernoError::InvalidArgument(format!("Failed to cancel session: {}", e)))?;
println!("Session {} has been cancelled", session_id);
Ok(())
}
async fn handle_formats_command(
processor: &MultiModalProcessor,
media_type: Option<String>,
examples: bool,
) -> Result<(), InfernoError> {
let formats = processor.get_supported_formats();
match media_type {
Some(mt) => {
if let Some(format_list) = formats.get(&mt) {
println!("Supported {} formats:", mt);
for format in format_list {
println!(" • {}", format);
}
} else {
println!("Unknown media type: {}", mt);
println!(
"Available types: {}",
formats.keys().cloned().collect::<Vec<_>>().join(", ")
);
}
}
None => {
println!("📁 Supported Media Formats:");
println!("{}", "=".repeat(40));
for (media_type, format_list) in formats {
println!("\n🎯 {}:", media_type.to_uppercase());
let chunks: Vec<_> = format_list.chunks(6).collect();
for chunk in chunks {
println!(" {}", chunk.join(", "));
}
}
}
}
if examples {
println!("\n💡 Usage Examples:");
println!("{}", "=".repeat(20));
println!("# Process an image with text prompt");
println!(
"inferno multimodal process -m gpt-4-vision -i image.jpg -p \"What's in this image?\""
);
println!("\n# Batch process audio files");
println!(
"inferno multimodal batch -m whisper-large -i ./audio/ -p transcribe -o ./results/"
);
println!("\n# Analyze video metadata");
println!("inferno multimodal analyze -i video.mp4 --detailed");
}
Ok(())
}
async fn handle_capabilities_command(
_processor: &MultiModalProcessor,
model: Option<String>,
media_type: Option<String>,
format: String,
) -> Result<(), InfernoError> {
let mut capabilities = HashMap::new();
capabilities.insert(
"gpt-4-vision".to_string(),
ModelCapabilities {
supports_text: true,
supports_images: true,
supports_audio: false,
supports_video: false,
max_context_length: Some(8000),
supported_languages: vec!["en".to_string(), "es".to_string(), "fr".to_string()],
vision_features: Some(VisionFeatures {
object_detection: true,
ocr: true,
scene_understanding: true,
face_recognition: false,
image_generation: false,
max_image_size: (2048, 2048),
}),
audio_features: None,
},
);
capabilities.insert(
"whisper-large".to_string(),
ModelCapabilities {
supports_text: true,
supports_images: false,
supports_audio: true,
supports_video: false,
max_context_length: None,
supported_languages: vec![
"en".to_string(),
"es".to_string(),
"fr".to_string(),
"de".to_string(),
],
vision_features: None,
audio_features: Some(AudioFeatures {
speech_to_text: true,
audio_classification: false,
music_analysis: false,
voice_synthesis: false,
noise_reduction: true,
max_audio_length_seconds: 3600,
}),
},
);
let filtered_capabilities: HashMap<String, ModelCapabilities> = match (&model, &media_type) {
(Some(m), _) => capabilities
.into_iter()
.filter(|(name, _)| name == m)
.collect(),
(None, Some(mt)) => capabilities
.into_iter()
.filter(|(_, caps)| match mt.as_str() {
"text" => caps.supports_text,
"image" => caps.supports_images,
"audio" => caps.supports_audio,
"video" => caps.supports_video,
_ => false,
})
.collect(),
(None, None) => capabilities,
};
if filtered_capabilities.is_empty() {
println!("No models found matching criteria");
return Ok(());
}
match format.as_str() {
"json" => {
println!(
"{}",
serde_json::to_string_pretty(&filtered_capabilities).map_err(|e| {
InfernoError::InvalidArgument(format!("JSON serialization failed: {}", e))
})?
);
}
"table" => {
println!("🤖 Model Capabilities:");
println!("{}", "=".repeat(80));
for (model_name, caps) in filtered_capabilities {
println!("\n📋 Model: {}", model_name);
println!(" Text: {}", if caps.supports_text { "✅" } else { "❌" });
println!(
" Images: {}",
if caps.supports_images { "✅" } else { "❌" }
);
println!(
" Audio: {}",
if caps.supports_audio { "✅" } else { "❌" }
);
println!(
" Video: {}",
if caps.supports_video { "✅" } else { "❌" }
);
if let Some(max_ctx) = caps.max_context_length {
println!(" Max Context: {} tokens", max_ctx);
}
if !caps.supported_languages.is_empty() {
println!(" Languages: {}", caps.supported_languages.join(", "));
}
if let Some(vision) = caps.vision_features {
println!(" Vision Features:");
println!(
" • Object Detection: {}",
if vision.object_detection {
"✅"
} else {
"❌"
}
);
println!(" • OCR: {}", if vision.ocr { "✅" } else { "❌" });
println!(
" • Scene Understanding: {}",
if vision.scene_understanding {
"✅"
} else {
"❌"
}
);
println!(
" • Max Image Size: {}x{}",
vision.max_image_size.0, vision.max_image_size.1
);
}
if let Some(audio) = caps.audio_features {
println!(" Audio Features:");
println!(
" • Speech to Text: {}",
if audio.speech_to_text { "✅" } else { "❌" }
);
println!(
" • Audio Classification: {}",
if audio.audio_classification {
"✅"
} else {
"❌"
}
);
println!(
" • Music Analysis: {}",
if audio.music_analysis { "✅" } else { "❌" }
);
println!(
" • Max Audio Length: {}s",
audio.max_audio_length_seconds
);
}
}
}
_ => {
return Err(InfernoError::InvalidArgument(format!(
"Unsupported format: {}",
format
)));
}
}
Ok(())
}
async fn handle_register_model_command(
processor: &MultiModalProcessor,
model: String,
config_file: PathBuf,
) -> Result<(), InfernoError> {
if model.is_empty() {
return Err(InfernoError::InvalidArgument(
"Model name cannot be empty".to_string(),
));
}
if !config_file.exists() {
return Err(InfernoError::InvalidArgument(format!(
"Config file does not exist: {:?}",
config_file
)));
}
let config_content = tokio::fs::read_to_string(&config_file).await?;
let capabilities: ModelCapabilities = serde_json::from_str(&config_content)
.map_err(|e| InfernoError::InvalidArgument(format!("Invalid JSON config: {}", e)))?;
processor
.register_model_capabilities(model.clone(), capabilities)
.await
.map_err(|e| InfernoError::InvalidArgument(format!("Failed to register model: {}", e)))?;
println!("✅ Model '{}' capabilities registered successfully", model);
Ok(())
}
async fn handle_analyze_command(
_processor: &MultiModalProcessor,
input: PathBuf,
detailed: bool,
format: String,
) -> Result<(), InfernoError> {
if !input.exists() {
return Err(InfernoError::InvalidArgument(format!(
"Input file does not exist: {:?}",
input
)));
}
println!("Analyzing file: {:?}", input);
let file_metadata = tokio::fs::metadata(&input).await?;
let file_extension = input
.extension()
.and_then(|ext| ext.to_str())
.unwrap_or("unknown");
let analysis = create_mock_analysis(&input, &file_metadata, file_extension, detailed);
match format.as_str() {
"json" => {
println!(
"{}",
serde_json::to_string_pretty(&analysis).map_err(|e| {
InfernoError::InvalidArgument(format!("JSON serialization failed: {}", e))
})?
);
}
"table" => {
print_analysis_table(&analysis);
}
_ => {
return Err(InfernoError::InvalidArgument(format!(
"Unsupported format: {}",
format
)));
}
}
Ok(())
}
async fn handle_convert_command(
input: PathBuf,
output: PathBuf,
format: String,
quality: Option<u32>,
params: Vec<String>,
) -> Result<(), InfernoError> {
if !input.exists() {
return Err(InfernoError::InvalidArgument(format!(
"Input file does not exist: {:?}",
input
)));
}
if format.is_empty() {
return Err(InfernoError::InvalidArgument(
"Target format cannot be empty".to_string(),
));
}
if let Some(q) = quality {
if q > 100 {
return Err(InfernoError::InvalidArgument(format!(
"Quality must be between 0 and 100, got: {}",
q
)));
}
}
println!(
"Converting {:?} to {:?} (format: {})",
input, output, format
);
if let Some(q) = quality {
println!("Quality setting: {}", q);
}
if !params.is_empty() {
println!("Additional parameters: {:?}", params);
}
let input_data = tokio::fs::read(&input).await?;
println!("🔄 Converting...");
tokio::time::sleep(Duration::from_millis(500)).await;
tokio::fs::write(&output, &input_data).await?;
println!("✅ Conversion completed: {:?}", output);
Ok(())
}
async fn find_matching_files(dir: &PathBuf, pattern: &str) -> Result<Vec<PathBuf>, InfernoError> {
let mut files = Vec::new();
let mut entries = tokio::fs::read_dir(dir).await?;
while let Some(entry) = entries.next_entry().await? {
let path = entry.path();
if path.is_file() && (pattern == "*" || path.to_string_lossy().contains(pattern)) {
files.push(path);
}
}
Ok(files)
}
fn format_text_output(result: &crate::multimodal::MultiModalResult) -> String {
let mut output = String::new();
output.push_str("🔥 Multi-Modal Processing Result\n");
output.push_str(&format!("{}\n", "=".repeat(40)));
output.push_str(&format!("Session ID: {}\n", result.id));
output.push_str(&format!("Model: {}\n", result.model_used));
output.push_str(&format!("Input: {}\n", result.input_summary));
output.push_str(&format!(
"Processing Time: {}ms\n",
result.processing_time_ms
));
output.push_str(&format!(
"Created: {}\n\n",
result.created_at.format("%Y-%m-%d %H:%M:%S")
));
output.push_str("📋 Processed Components:\n");
for (i, component) in result.processed_components.iter().enumerate() {
output.push_str(&format!(
" {}. {} - {}\n",
i + 1,
component.component_type,
component.description
));
if component.processing_time_ms > 0 {
output.push_str(&format!(
" Processing time: {}ms\n",
component.processing_time_ms
));
}
}
output.push_str(&format!("\n🎯 Result:\n{}\n", result.inference_result));
if let Some(scores) = &result.confidence_scores {
output.push_str("\n📊 Confidence Scores:\n");
for (key, score) in scores {
output.push_str(&format!(" • {}: {:.2}\n", key, score));
}
}
output
}
fn create_mock_analysis(
path: &PathBuf,
metadata: &std::fs::Metadata,
extension: &str,
detailed: bool,
) -> serde_json::Value {
let mut analysis = serde_json::json!({
"file_path": path,
"file_size_bytes": metadata.len(),
"file_extension": extension,
"file_type": determine_file_type(extension),
"last_modified": metadata.modified().ok().map(|t| {
chrono::DateTime::<chrono::Utc>::from(t).format("%Y-%m-%d %H:%M:%S").to_string()
})
});
if detailed {
match extension {
"jpg" | "jpeg" | "png" | "bmp" | "gif" | "webp" | "tiff" => {
analysis["image_analysis"] = serde_json::json!({
"format": extension.to_uppercase(),
"estimated_dimensions": "1920x1080",
"estimated_channels": 3,
"color_space": "RGB",
"has_exif": true
});
}
"mp3" | "wav" | "flac" | "ogg" | "m4a" | "aac" => {
analysis["audio_analysis"] = serde_json::json!({
"format": extension.to_uppercase(),
"estimated_duration": "120.5 seconds",
"estimated_sample_rate": "44100 Hz",
"estimated_channels": 2,
"estimated_bitrate": "128 kbps"
});
}
"mp4" | "avi" | "mov" | "mkv" | "webm" => {
analysis["video_analysis"] = serde_json::json!({
"format": extension.to_uppercase(),
"estimated_duration": "300.0 seconds",
"estimated_resolution": "1920x1080",
"estimated_frame_rate": "30 fps",
"has_audio": true,
"estimated_video_codec": "H.264",
"estimated_audio_codec": "AAC"
});
}
_ => {
analysis["file_analysis"] = serde_json::json!({
"type": "unknown",
"mime_type": "application/octet-stream"
});
}
}
}
analysis
}
fn determine_file_type(extension: &str) -> &'static str {
match extension.to_lowercase().as_str() {
"jpg" | "jpeg" | "png" | "bmp" | "gif" | "webp" | "tiff" => "image",
"mp3" | "wav" | "flac" | "ogg" | "m4a" | "aac" => "audio",
"mp4" | "avi" | "mov" | "mkv" | "webm" => "video",
"txt" | "md" | "json" | "xml" | "csv" => "text",
_ => "unknown",
}
}
fn print_analysis_table(analysis: &serde_json::Value) {
println!("📊 File Analysis Results:");
println!("{}", "=".repeat(50));
if let Some(path) = analysis.get("file_path") {
println!("File: {}", path.as_str().unwrap_or("unknown"));
}
if let Some(size) = analysis.get("file_size_bytes") {
let size_mb = size.as_u64().unwrap_or(0) as f64 / 1024.0 / 1024.0;
println!(
"Size: {:.2} MB ({} bytes)",
size_mb,
size.as_u64().unwrap_or(0)
);
}
if let Some(file_type) = analysis.get("file_type") {
println!("Type: {}", file_type.as_str().unwrap_or("unknown"));
}
if let Some(ext) = analysis.get("file_extension") {
println!("Extension: {}", ext.as_str().unwrap_or("unknown"));
}
if let Some(modified) = analysis.get("last_modified") {
println!("Modified: {}", modified.as_str().unwrap_or("unknown"));
}
if let Some(img_analysis) = analysis.get("image_analysis") {
println!("\n🖼️ Image Analysis:");
if let Some(dims) = img_analysis.get("estimated_dimensions") {
println!(" Dimensions: {}", dims.as_str().unwrap_or("unknown"));
}
if let Some(channels) = img_analysis.get("estimated_channels") {
println!(" Channels: {}", channels.as_u64().unwrap_or(0));
}
if let Some(color_space) = img_analysis.get("color_space") {
println!(
" Color Space: {}",
color_space.as_str().unwrap_or("unknown")
);
}
}
if let Some(audio_analysis) = analysis.get("audio_analysis") {
println!("\n🎵 Audio Analysis:");
if let Some(duration) = audio_analysis.get("estimated_duration") {
println!(" Duration: {}", duration.as_str().unwrap_or("unknown"));
}
if let Some(sample_rate) = audio_analysis.get("estimated_sample_rate") {
println!(
" Sample Rate: {}",
sample_rate.as_str().unwrap_or("unknown")
);
}
if let Some(channels) = audio_analysis.get("estimated_channels") {
println!(" Channels: {}", channels.as_u64().unwrap_or(0));
}
if let Some(bitrate) = audio_analysis.get("estimated_bitrate") {
println!(" Bitrate: {}", bitrate.as_str().unwrap_or("unknown"));
}
}
if let Some(video_analysis) = analysis.get("video_analysis") {
println!("\n🎬 Video Analysis:");
if let Some(duration) = video_analysis.get("estimated_duration") {
println!(" Duration: {}", duration.as_str().unwrap_or("unknown"));
}
if let Some(resolution) = video_analysis.get("estimated_resolution") {
println!(" Resolution: {}", resolution.as_str().unwrap_or("unknown"));
}
if let Some(frame_rate) = video_analysis.get("estimated_frame_rate") {
println!(" Frame Rate: {}", frame_rate.as_str().unwrap_or("unknown"));
}
if let Some(video_codec) = video_analysis.get("estimated_video_codec") {
println!(
" Video Codec: {}",
video_codec.as_str().unwrap_or("unknown")
);
}
if let Some(audio_codec) = video_analysis.get("estimated_audio_codec") {
println!(
" Audio Codec: {}",
audio_codec.as_str().unwrap_or("unknown")
);
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use std::io::Write;
use tempfile::NamedTempFile;
#[test]
fn test_process_file_config_empty_model_validation() {
let config = ProcessFileConfig {
model: "".to_string(),
input: PathBuf::from("/tmp/test.jpg"),
prompt: None,
max_tokens: 500,
temperature: 0.7,
output_format: "text".to_string(),
output_file: None,
show_progress: false,
};
assert!(config.model.is_empty());
}
#[test]
fn test_process_base64_config_empty_data_validation() {
let config = ProcessBase64Config {
model: "test-model".to_string(),
data: "".to_string(),
media_type: "jpg".to_string(),
prompt: None,
max_tokens: 500,
temperature: 0.7,
output_format: "text".to_string(),
};
assert!(config.data.is_empty());
}
#[test]
fn test_batch_process_config_zero_concurrent_validation() {
let config = BatchProcessConfig {
model: "test-model".to_string(),
input_dir: PathBuf::from("/tmp"),
pattern: "*".to_string(),
prompt: None,
output_dir: PathBuf::from("/tmp/output"),
max_concurrent: 0,
continue_on_error: false,
};
assert_eq!(config.max_concurrent, 0);
}
#[test]
fn test_determine_file_type_image() {
assert_eq!(determine_file_type("jpg"), "image");
assert_eq!(determine_file_type("jpeg"), "image");
assert_eq!(determine_file_type("png"), "image");
assert_eq!(determine_file_type("gif"), "image");
assert_eq!(determine_file_type("webp"), "image");
}
#[test]
fn test_determine_file_type_audio() {
assert_eq!(determine_file_type("mp3"), "audio");
assert_eq!(determine_file_type("wav"), "audio");
assert_eq!(determine_file_type("flac"), "audio");
assert_eq!(determine_file_type("ogg"), "audio");
}
#[test]
fn test_determine_file_type_video() {
assert_eq!(determine_file_type("mp4"), "video");
assert_eq!(determine_file_type("avi"), "video");
assert_eq!(determine_file_type("mov"), "video");
assert_eq!(determine_file_type("mkv"), "video");
}
#[test]
fn test_determine_file_type_text() {
assert_eq!(determine_file_type("txt"), "text");
assert_eq!(determine_file_type("md"), "text");
assert_eq!(determine_file_type("json"), "text");
}
#[test]
fn test_determine_file_type_unknown() {
assert_eq!(determine_file_type("xyz"), "unknown");
assert_eq!(determine_file_type(""), "unknown");
}
#[test]
fn test_create_mock_analysis_basic() {
let temp_file = NamedTempFile::new().unwrap();
let path = temp_file.path().to_path_buf();
let metadata = std::fs::metadata(&path).unwrap();
let analysis = create_mock_analysis(&path, &metadata, "jpg", false);
assert!(analysis.get("file_path").is_some());
assert!(analysis.get("file_size_bytes").is_some());
assert!(analysis.get("file_extension").is_some());
assert!(analysis.get("file_type").is_some());
}
#[test]
fn test_create_mock_analysis_detailed_image() {
let temp_file = NamedTempFile::new().unwrap();
let path = temp_file.path().to_path_buf();
let metadata = std::fs::metadata(&path).unwrap();
let analysis = create_mock_analysis(&path, &metadata, "png", true);
assert!(analysis.get("image_analysis").is_some());
let img_analysis = analysis.get("image_analysis").unwrap();
assert!(img_analysis.get("format").is_some());
assert!(img_analysis.get("estimated_dimensions").is_some());
}
#[test]
fn test_create_mock_analysis_detailed_audio() {
let temp_file = NamedTempFile::new().unwrap();
let path = temp_file.path().to_path_buf();
let metadata = std::fs::metadata(&path).unwrap();
let analysis = create_mock_analysis(&path, &metadata, "mp3", true);
assert!(analysis.get("audio_analysis").is_some());
let audio_analysis = analysis.get("audio_analysis").unwrap();
assert!(audio_analysis.get("format").is_some());
assert!(audio_analysis.get("estimated_duration").is_some());
}
#[test]
fn test_create_mock_analysis_detailed_video() {
let temp_file = NamedTempFile::new().unwrap();
let path = temp_file.path().to_path_buf();
let metadata = std::fs::metadata(&path).unwrap();
let analysis = create_mock_analysis(&path, &metadata, "mp4", true);
assert!(analysis.get("video_analysis").is_some());
let video_analysis = analysis.get("video_analysis").unwrap();
assert!(video_analysis.get("format").is_some());
assert!(video_analysis.get("estimated_resolution").is_some());
}
#[test]
fn test_supported_media_types() {
let supported_types = [
"jpg", "jpeg", "png", "gif", "webp", "bmp", "tiff", "mp3", "wav", "flac", "ogg", "m4a", "aac", "mp4", "avi", "mov", "mkv", "webm", ];
for media_type in &supported_types {
assert!(
supported_types.contains(media_type),
"Type {} should be supported",
media_type
);
}
assert!(!supported_types.contains(&"invalid"));
}
}