use base64::{engine::general_purpose::STANDARD, Engine as _};
use log::debug;
use reqwest::Client;
use serde_json::{json, Value};
use std::error::Error;
/// Where the image handed to the OCR pipeline comes from.
///
/// `PartialEq`/`Eq` are derived so that values can be compared directly
/// (for example in tests or when de-duplicating requests).
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ImageSource {
    /// Path of an image file that is read from disk and base64-encoded before upload.
    Path(String),
    /// Image content that is already base64-encoded; it is forwarded unchanged.
    Base64(String),
}
/// Runs OCR on the given image source and returns the extracted text.
///
/// A [`ImageSource::Path`] is read from disk first; a [`ImageSource::Base64`]
/// payload is passed on as-is.
///
/// # Errors
///
/// Propagates whatever error the selected extraction routine reports.
pub async fn extract(source: &ImageSource) -> Result<String, Box<dyn Error + Send + Sync>> {
    let ocr_text = match source {
        ImageSource::Base64(encoded) => extract_from_base64(encoded).await?,
        ImageSource::Path(file_path) => extract_from_file(file_path).await?,
    };
    Ok(ocr_text)
}
/// Reads the file at `path`, base64-encodes its bytes and sends them for OCR.
///
/// # Errors
///
/// Returns the I/O error if the file cannot be read, or any error reported by
/// `call_google_vision`.
async fn extract_from_file(path: &str) -> Result<String, Box<dyn Error + Send + Sync>> {
    let raw_bytes = tokio::fs::read(path).await?;
    let encoded_image = STANDARD.encode(raw_bytes.as_slice());
    let ocr_text = call_google_vision(encoded_image.as_str()).await?;
    Ok(ocr_text)
}
/// Sends an already base64-encoded image for OCR without touching the payload.
///
/// # Errors
///
/// Propagates any error reported by `call_google_vision`.
async fn extract_from_base64(data: &str) -> Result<String, Box<dyn Error + Send + Sync>> {
    let ocr_text = call_google_vision(data).await?;
    Ok(ocr_text)
}
async fn call_google_vision(base64_image: &str) -> Result<String, Box<dyn Error + Send + Sync>> {
let api_key = std::env::var("GOOGLE_API_KEY")
.map_err(|_| "GOOGLE_API_KEY environment variable not set")?;
let client = Client::new();
let url = format!(
"https://vision.googleapis.com/v1/images:annotate?key={}",
api_key
);
let request_body = json!({
"requests": [{
"image": {
"content": base64_image
},
"features": [{
"type": "TEXT_DETECTION"
}]
}]
});
debug!("Sending OCR request to Google Vision API");
let response = client
.post(&url)
.header("Accept-Encoding", "identity")
.json(&request_body)
.send()
.await?;
if !response.status().is_success() {
let status = response.status();
let error_text = response.text().await?;
return Err(format!("Google Vision API error ({}): {}", status, error_text).into());
}
let response_body: Value = response.json().await?;
debug!("Google Vision API response: {:?}", response_body);
let text = response_body["responses"][0]["fullTextAnnotation"]["text"]
.as_str()
.ok_or("No text found in image")?
.to_string();
if text.trim().is_empty() {
return Err("No text detected in image".into());
}
debug!("Extracted text from image: {} characters", text.len());
Ok(text)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Name of the environment variable holding the Vision API key.
    const API_KEY_VAR: &str = "GOOGLE_API_KEY";

    /// Puts `GOOGLE_API_KEY` back to its saved state when dropped, so the variable is
    /// restored even if an assertion in the test panics.
    struct ApiKeyRestore(Option<String>);

    impl Drop for ApiKeyRestore {
        fn drop(&mut self) {
            match self.0.take() {
                Some(key) => std::env::set_var(API_KEY_VAR, key),
                None => std::env::remove_var(API_KEY_VAR),
            }
        }
    }

    /// The standard engine must produce the exact, padded-alphabet encoding.
    #[test]
    fn test_base64_encoding() {
        let encoded = STANDARD.encode(b"test data");
        assert_eq!(encoded, "dGVzdCBkYXRh");
    }

    /// Without an API key the call must fail before any request is sent.
    // NOTE(review): this mutates process-wide environment state; other tests that read
    // `GOOGLE_API_KEY` concurrently could observe the temporary removal.
    #[tokio::test]
    async fn test_ocr_requires_api_key() {
        let _restore = ApiKeyRestore(std::env::var(API_KEY_VAR).ok());
        std::env::remove_var(API_KEY_VAR);

        let fake_base64_image = STANDARD.encode(b"fake image data");
        let result = call_google_vision(&fake_base64_image).await;

        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("GOOGLE_API_KEY"));
    }
}