use thiserror::Error;
use super::{Segment, SegmentEmbedding};
/// Errors returned by the embedding model API in this module.
///
/// Each variant's `Display` text comes from its `#[error(...)]` attribute;
/// for the variants that carry a `String`, that string is interpolated into
/// the message.
#[derive(Debug, Error)]
pub enum EmbeddingError {
/// Model loading failed; the payload is appended to "Failed to load model: ".
/// NOTE(review): no code visible in this file constructs this variant.
#[error("Failed to load model: {0}")]
ModelLoadError(String),
/// Inference failed; the payload is appended to "Inference failed: ".
/// NOTE(review): no code visible in this file constructs this variant.
#[error("Inference failed: {0}")]
InferenceError(String),
/// The caller supplied unusable input; the payload describes the problem.
/// `EmbeddingModel::embed_text` returns this for an empty string.
#[error("Invalid input: {0}")]
InvalidInput(String),
/// The model was used before being initialized.
/// NOTE(review): no code visible in this file constructs this variant.
#[error("Model not initialized")]
NotInitialized,
}
/// Settings used to construct an [`EmbeddingModel`].
#[derive(Debug, Clone)]
pub struct EmbeddingModelConfig {
    /// Identifier of the model to use.
    pub model_id: String,
    /// Number of `f32` components in every embedding vector produced.
    pub embedding_dim: usize,
    /// Batch size setting.
    /// NOTE(review): not read by any code visible in this file.
    pub batch_size: usize,
    /// GPU toggle.
    /// NOTE(review): not read by any code visible in this file.
    pub use_gpu: bool,
}

impl Default for EmbeddingModelConfig {
    /// Returns the stock configuration: model `"birdnet-v2.4"`, 1024-dimensional
    /// embeddings, a batch size of 32, and GPU use disabled.
    fn default() -> Self {
        let model_id = String::from("birdnet-v2.4");
        Self {
            model_id,
            embedding_dim: 1024,
            batch_size: 32,
            use_gpu: false,
        }
    }
}
/// Embedding generator that owns the configuration it was created with.
///
/// The methods visible in this file produce zero-filled vectors whose length
/// is `config.embedding_dim`.
pub struct EmbeddingModel {
/// Configuration passed to `EmbeddingModel::new`.
config: EmbeddingModelConfig,
}
impl EmbeddingModel {
pub async fn new(config: EmbeddingModelConfig) -> Result<Self, EmbeddingError> {
Ok(Self { config })
}
pub async fn embed_batch(
&self,
segments: &[Segment],
) -> Result<Vec<SegmentEmbedding>, EmbeddingError> {
let embeddings = segments
.iter()
.map(|seg| SegmentEmbedding {
id: seg.id,
recording_id: seg.recording_id,
embedding: vec![0.0; self.config.embedding_dim],
start_time: seg.start_time,
end_time: seg.end_time,
species: seg.species.clone(),
})
.collect();
Ok(embeddings)
}
pub async fn embed_text(&self, text: &str) -> Result<Vec<f32>, EmbeddingError> {
if text.is_empty() {
return Err(EmbeddingError::InvalidInput("Empty text".to_string()));
}
Ok(vec![0.0; self.config.embedding_dim])
}
#[must_use]
pub fn embedding_dim(&self) -> usize {
self.config.embedding_dim
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Constructing a model from the default configuration succeeds.
    #[tokio::test]
    async fn test_embedding_model_creation() {
        let config = EmbeddingModelConfig::default();
        assert!(EmbeddingModel::new(config).await.is_ok());
    }

    /// Embedding an empty string is rejected with `InvalidInput`.
    #[tokio::test]
    async fn test_embed_text_empty() {
        let model = EmbeddingModel::new(EmbeddingModelConfig::default())
            .await
            .unwrap();
        let outcome = model.embed_text("").await;
        assert!(matches!(outcome, Err(EmbeddingError::InvalidInput(_))));
    }
}