use std::sync::Arc;
use async_trait::async_trait;
use crate::error::HirnResult;
/// A single dense embedding vector together with an identifier of the model
/// that produced it.
#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)]
pub struct Embedding {
    /// The embedding values. Presumably `Embedder::dimensions()` long, but the
    /// length is not enforced by this type.
    pub vector: Vec<f32>,
    /// Identifier of the producing model (presumably the embedder's
    /// `model_id()` — confirm with implementations).
    pub model_id: String,
}
/// Several embedding vectors produced for one input by
/// `Embedder::embed_multivec`.
///
/// NOTE(review): the granularity of the vectors (per token, per segment, ...)
/// is not defined in this module — presumably late-interaction style; confirm
/// with implementations.
#[derive(Debug, Clone, PartialEq)]
pub struct MultivectorEmbedding {
    /// The individual vectors; presumably each of the embedder's dimension.
    pub vectors: Vec<Vec<f32>>,
    /// Identifier of the producing model.
    pub model_id: String,
}
/// One entry of a `Reranker::rerank` result.
#[derive(Debug, Clone, PartialEq)]
pub struct RerankResult {
    /// Position of the document in the `documents` slice passed to `rerank`.
    pub index: usize,
    /// Relevance score. `NoopReranker` emits scores that decrease with rank,
    /// so presumably higher means more relevant — confirm for other rerankers.
    pub score: f32,
}
/// An entity found in text by an `EntityExtractor`.
#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)]
pub struct ExtractedEntity {
    /// Name of the entity as extracted from the text.
    pub name: String,
    /// Entity category; presumably one of the `entity_types` requested from
    /// `extract_entities` — confirm with implementations.
    pub entity_type: String,
    /// Extractor confidence. The range is not enforced here (presumably
    /// 0.0..=1.0 — TODO confirm).
    pub confidence: f32,
}
/// A relation between extracted entities, as returned by
/// `EntityExtractor::extract_relations`.
#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)]
pub struct ExtractedRelation {
    /// Presumably the `name` of the source `ExtractedEntity` — confirm.
    pub source: String,
    /// Presumably the `name` of the target `ExtractedEntity` — confirm.
    pub target: String,
    /// Relation label; its vocabulary is not fixed by this module.
    pub relation_type: String,
    /// Relation strength; scale and range are not defined here — TODO confirm.
    pub weight: f32,
}
#[async_trait]
pub trait Embedder: Send + Sync {
async fn embed(&self, texts: &[&str]) -> HirnResult<Vec<Embedding>>;
fn dimensions(&self) -> usize;
fn model_id(&self) -> &str;
fn max_input_tokens(&self) -> usize;
async fn embed_multivec(&self, _texts: &[&str]) -> HirnResult<Vec<MultivectorEmbedding>> {
Err(crate::error::HirnError::InvalidInput(
"this embedder does not support multivector embeddings".into(),
))
}
fn supports_multivec(&self) -> bool {
false
}
}
/// Makes a shared `Arc<T>` usable wherever an [`Embedder`] is expected.
///
/// Every method is forwarded unchanged to the pointee. Because `T: ?Sized`,
/// this also covers `Arc<dyn Embedder>`.
#[async_trait]
impl<T: Embedder + ?Sized> Embedder for Arc<T> {
    async fn embed(&self, texts: &[&str]) -> HirnResult<Vec<Embedding>> {
        (**self).embed(texts).await
    }

    fn dimensions(&self) -> usize {
        (**self).dimensions()
    }

    fn model_id(&self) -> &str {
        (**self).model_id()
    }

    fn max_input_tokens(&self) -> usize {
        (**self).max_input_tokens()
    }

    async fn embed_multivec(&self, texts: &[&str]) -> HirnResult<Vec<MultivectorEmbedding>> {
        (**self).embed_multivec(texts).await
    }

    fn supports_multivec(&self) -> bool {
        (**self).supports_multivec()
    }
}
/// Estimates how many model tokens a piece of text will consume.
pub trait TokenCounter: Send + Sync {
    /// Returns the estimated token count for `text`.
    fn count_tokens(&self, text: &str) -> usize;

    /// Counts each entry of `texts`, returning one count per input in the
    /// same order.
    ///
    /// The default calls [`TokenCounter::count_tokens`] once per element.
    fn count_tokens_batch(&self, texts: &[&str]) -> Vec<usize> {
        let mut counts = Vec::with_capacity(texts.len());
        for &text in texts {
            counts.push(self.count_tokens(text));
        }
        counts
    }
}

/// Heuristic [`TokenCounter`] that assumes roughly four bytes per token.
///
/// It measures the UTF-8 *byte* length (`str::len`), not the number of
/// `char`s, so multi-byte characters weigh more. A partial group of bytes
/// rounds up, so any non-empty text counts as at least one token.
#[derive(Debug, Clone, Copy)]
pub struct CharEstimateCounter;

impl TokenCounter for CharEstimateCounter {
    fn count_tokens(&self, text: &str) -> usize {
        let bytes = text.len();
        // Ceiling division without the `bytes + 3` overflow hazard.
        bytes / 4 + usize::from(bytes % 4 != 0)
    }
}
/// Reorders candidate documents by relevance to a query.
#[async_trait]
pub trait Reranker: Send + Sync {
    /// Scores `documents` against `query`.
    ///
    /// Each returned `RerankResult::index` refers back into `documents`.
    /// Presumably at most `top_k` results are returned, best first (as
    /// `NoopReranker` does) — confirm for each implementation.
    async fn rerank(
        &self,
        query: &str,
        documents: &[&str],
        top_k: usize,
    ) -> HirnResult<Vec<RerankResult>>;
}
/// A [`Reranker`] that does no real scoring.
///
/// It keeps the documents in their original order and returns the first
/// `top_k`. Scores are synthetic, `1.0 - index / documents.len()`, so they
/// start at `1.0` and strictly decrease with position. The query is ignored.
#[derive(Debug, Clone, Copy)]
pub struct NoopReranker;

#[async_trait]
impl Reranker for NoopReranker {
    async fn rerank(
        &self,
        _query: &str,
        documents: &[&str],
        top_k: usize,
    ) -> HirnResult<Vec<RerankResult>> {
        // `max(1)` only matters for an empty slice, where nothing is produced anyway.
        let denom = documents.len().max(1) as f32;
        let kept = documents.len().min(top_k);
        let ranked: Vec<RerankResult> = (0..kept)
            .map(|index| RerankResult {
                index,
                score: 1.0 - index as f32 / denom,
            })
            .collect();
        Ok(ranked)
    }
}
/// Extracts entities, and relations between them, from free text.
#[async_trait]
pub trait EntityExtractor: Send + Sync {
    /// Finds entities in `text`.
    ///
    /// `entity_types` presumably restricts or guides which categories are
    /// returned — confirm per implementation.
    async fn extract_entities(
        &self,
        text: &str,
        entity_types: &[&str],
    ) -> HirnResult<Vec<ExtractedEntity>>;
    /// Finds relations in `text` among the previously extracted `entities`.
    async fn extract_relations(
        &self,
        text: &str,
        entities: &[ExtractedEntity],
    ) -> HirnResult<Vec<ExtractedRelation>>;
}
/// One message of a chat conversation sent to an `LlmProvider`.
#[derive(Debug, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
pub struct ChatMessage {
    /// Speaker role. The accepted values are not defined in this module —
    /// presumably the provider's chat roles (e.g. "system"/"user"/"assistant");
    /// confirm with providers.
    pub role: String,
    /// Message text.
    pub content: String,
}
/// Output format requested from an LLM via `LlmOptions::response_format`.
#[derive(Debug, Clone, PartialEq, Eq, Default)]
pub enum ResponseFormat {
    /// Unconstrained text; this is the default.
    #[default]
    Text,
    /// Presumably "any JSON object" (provider JSON mode) — confirm with providers.
    JsonObject,
    /// Schema-constrained JSON; the string presumably holds the schema as
    /// JSON text — confirm with callers.
    JsonSchema(String),
}
/// Token accounting for one LLM call.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub struct TokenUsage {
    /// Tokens consumed by the prompt (input) side.
    pub prompt_tokens: u32,
    /// Tokens produced by the completion (output) side.
    pub completion_tokens: u32,
}

impl TokenUsage {
    /// Returns `prompt_tokens + completion_tokens`, saturating at `u32::MAX`.
    ///
    /// The counts presumably come from provider responses, i.e. external data.
    /// A plain `+` would panic on overflow in debug builds and silently wrap
    /// to a tiny number in release builds. Saturating keeps the total monotone
    /// and never panics.
    pub const fn total(&self) -> u32 {
        self.prompt_tokens.saturating_add(self.completion_tokens)
    }
}
/// Result of a non-streaming LLM call (`LlmProvider::generate`).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct LlmResponse {
    /// Generated text.
    pub content: String,
    /// Token accounting, or `None` when unavailable. The default
    /// `LlmProvider::generate` always sets `None`.
    pub usage: Option<TokenUsage>,
}
/// One item of an `LlmStream`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct LlmChunk {
    /// Text produced for this chunk. The default
    /// `LlmProvider::generate_stream` puts the whole response into a single chunk.
    pub delta: String,
    /// Token accounting if the provider reports it with this chunk (the default
    /// stream sets `None`). Presumably only the final chunk carries it — confirm.
    pub usage: Option<TokenUsage>,
}
/// Per-request generation settings passed to `LlmProvider` methods.
#[derive(Debug, Clone, PartialEq)]
pub struct LlmOptions {
    /// When `Some`, presumably replaces the provider's default model for this
    /// request — confirm with implementations. Defaults to `None`.
    pub model_override: Option<String>,
    /// Sampling temperature (interpretation is provider-specific). Defaults to `0.0`.
    pub temperature: f32,
    /// Generation length cap, presumably in completion tokens. Defaults to `1024`.
    pub max_tokens: u32,
    /// Requested output format. Defaults to `ResponseFormat::Text`.
    pub response_format: ResponseFormat,
}

impl Default for LlmOptions {
    /// No model override, temperature `0.0`, a 1024-token cap and plain-text output.
    fn default() -> Self {
        // `ResponseFormat`'s derived default is its `#[default]` variant, `Text`.
        let response_format = ResponseFormat::default();
        Self {
            max_tokens: 1024,
            temperature: 0.0,
            response_format,
            model_override: None,
        }
    }
}
/// Boxed, pinned, `Send` stream of [`LlmChunk`]s.
///
/// Every item is a `HirnResult`, so an error can surface mid-stream rather
/// than only when the stream is created.
pub type LlmStream = std::pin::Pin<Box<dyn futures::Stream<Item = HirnResult<LlmChunk>> + Send>>;
/// A chat-style large-language-model backend.
///
/// Only [`LlmProvider::generate_text`] and [`LlmProvider::model_id`] are
/// required. `generate` and `generate_stream` have defaults built on
/// `generate_text`.
#[async_trait]
pub trait LlmProvider: Send + Sync {
    /// Identifier of the model this provider uses.
    fn model_id(&self) -> &str;

    /// Generates a completion for the conversation in `messages` and returns
    /// only its text.
    async fn generate_text(
        &self,
        messages: &[ChatMessage],
        options: &LlmOptions,
    ) -> HirnResult<String>;

    /// Generates a completion together with optional usage information.
    ///
    /// The default wraps [`LlmProvider::generate_text`] and reports no usage
    /// (`usage: None`). Providers that know their token counts can override it.
    async fn generate(
        &self,
        messages: &[ChatMessage],
        options: &LlmOptions,
    ) -> HirnResult<LlmResponse> {
        self.generate_text(messages, options)
            .await
            .map(|content| LlmResponse {
                content,
                usage: None,
            })
    }

    /// Streams a completion as a sequence of [`LlmChunk`]s.
    ///
    /// The default is not incremental. It awaits the full
    /// [`LlmProvider::generate_text`] result, then yields it as one chunk with
    /// `usage: None`. An error from `generate_text` is returned before any
    /// stream is created.
    async fn generate_stream(
        &self,
        messages: &[ChatMessage],
        options: &LlmOptions,
    ) -> HirnResult<LlmStream> {
        let full_text = self.generate_text(messages, options).await?;
        let only_chunk = LlmChunk {
            delta: full_text,
            usage: None,
        };
        Ok(Box::pin(futures::stream::once(async move { Ok(only_chunk) })))
    }
}
/// An embedder that may embed stored content ("source") and search queries
/// differently.
#[async_trait]
pub trait AsymmetricEmbedder: Send + Sync {
    /// Identifier of this embedder.
    fn name(&self) -> &str;

    /// Length of the produced vectors.
    fn dims(&self) -> usize;

    /// Embeds content to be stored or indexed.
    async fn embed_source(&self, texts: &[&str]) -> HirnResult<Vec<Embedding>>;

    /// Embeds search queries.
    ///
    /// Defaults to the symmetric behavior: queries are embedded exactly like
    /// sources via [`AsymmetricEmbedder::embed_source`].
    async fn embed_query(&self, texts: &[&str]) -> HirnResult<Vec<Embedding>> {
        Self::embed_source(self, texts).await
    }
}
pub struct EmbedderAdapter<E: Embedder> {
inner: E,
}
impl<E: Embedder> EmbedderAdapter<E> {
pub fn new(inner: E) -> Self {
Self { inner }
}
}
#[async_trait]
impl<E: Embedder> AsymmetricEmbedder for EmbedderAdapter<E> {
async fn embed_source(&self, texts: &[&str]) -> HirnResult<Vec<Embedding>> {
self.inner.embed(texts).await
}
fn name(&self) -> &str {
self.inner.model_id()
}
fn dims(&self) -> usize {
self.inner.dimensions()
}
}
/// Keeps the first `target_dims` components of `embedding` and rescales them
/// to unit L2 norm.
///
/// Intended for Matryoshka-style embeddings, whose leading dimensions are
/// presumably usable on their own (confirm that the model supports this).
///
/// Returns `None` when `embedding` has fewer than `target_dims` values.
/// A prefix whose norm is not positive (all zeros, or NaN) is returned
/// unnormalized rather than divided. `target_dims == 0` yields `Some(vec![])`.
#[must_use]
pub fn truncate_matryoshka(embedding: &[f32], target_dims: usize) -> Option<Vec<f32>> {
    let prefix = embedding.get(..target_dims)?;
    let magnitude = prefix.iter().map(|v| v * v).sum::<f32>().sqrt();
    let scaled = if magnitude > 0.0 {
        prefix.iter().map(|v| v / magnitude).collect()
    } else {
        prefix.to_vec()
    };
    Some(scaled)
}
#[cfg(test)]
mod tests {
    use super::*;

    // --- TokenCounter ------------------------------------------------------

    #[test]
    fn char_estimate_counter() {
        let c = CharEstimateCounter;
        assert_eq!(c.count_tokens(""), 0);
        assert_eq!(c.count_tokens("hi"), 1);
        assert_eq!(c.count_tokens("hello world"), 3);
    }

    #[test]
    fn char_estimate_counts_utf8_bytes() {
        // "é" is two UTF-8 bytes, so four of them are 8 bytes -> 2 tokens.
        assert_eq!(CharEstimateCounter.count_tokens("éééé"), 2);
    }

    #[test]
    fn char_estimate_batch() {
        let c = CharEstimateCounter;
        let counts = c.count_tokens_batch(&["a", "abcdefgh"]);
        assert_eq!(counts, vec![1, 2]);
        assert!(c.count_tokens_batch(&[]).is_empty());
    }

    // --- Reranker ----------------------------------------------------------

    #[tokio::test]
    async fn noop_reranker_returns_descending() {
        let r = NoopReranker;
        let docs = ["alpha", "beta", "gamma"];
        let results = r.rerank("q", &docs, 2).await.unwrap();
        assert_eq!(results.len(), 2);
        assert_eq!(results[0].index, 0);
        assert_eq!(results[1].index, 1);
        assert!(results[0].score >= results[1].score);
        assert_eq!(results[0].score, 1.0);
    }

    #[tokio::test]
    async fn noop_reranker_handles_empty_input_and_large_top_k() {
        let r = NoopReranker;
        assert!(r.rerank("q", &[], 5).await.unwrap().is_empty());
        let results = r.rerank("q", &["only"], 10).await.unwrap();
        assert_eq!(results, vec![RerankResult { index: 0, score: 1.0 }]);
    }

    // --- Matryoshka truncation ---------------------------------------------

    #[test]
    fn matryoshka_truncate() {
        let emb = vec![3.0, 4.0, 0.0, 0.0];
        let t = truncate_matryoshka(&emb, 2).unwrap();
        assert_eq!(t.len(), 2);
        let norm: f32 = t.iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!((norm - 1.0).abs() < 1e-5);
    }

    #[test]
    fn matryoshka_too_short() {
        assert!(truncate_matryoshka(&[1.0, 2.0], 5).is_none());
    }

    // --- LLM types ---------------------------------------------------------

    #[test]
    fn llm_options_default() {
        let o = LlmOptions::default();
        assert!(o.model_override.is_none());
        assert_eq!(o.temperature, 0.0);
        assert_eq!(o.max_tokens, 1024);
        assert_eq!(o.response_format, ResponseFormat::Text);
    }

    #[test]
    fn token_usage_total() {
        let u = TokenUsage {
            prompt_tokens: 10,
            completion_tokens: 5,
        };
        assert_eq!(u.total(), 15);
    }

    /// Provider that only implements the required methods, to exercise defaults.
    struct EchoLlm;

    #[async_trait]
    impl LlmProvider for EchoLlm {
        async fn generate_text(
            &self,
            messages: &[ChatMessage],
            _options: &LlmOptions,
        ) -> HirnResult<String> {
            Ok(messages
                .iter()
                .map(|m| m.content.as_str())
                .collect::<Vec<_>>()
                .join(" "))
        }
        fn model_id(&self) -> &str {
            "echo"
        }
    }

    #[tokio::test]
    async fn llm_provider_defaults_wrap_generate_text() {
        use futures::StreamExt;

        let msgs = [ChatMessage {
            role: "user".to_string(),
            content: "hi".to_string(),
        }];
        let opts = LlmOptions::default();

        let resp = EchoLlm.generate(&msgs, &opts).await.unwrap();
        assert_eq!(
            resp,
            LlmResponse {
                content: "hi".to_string(),
                usage: None,
            }
        );

        let chunks: Vec<_> = EchoLlm
            .generate_stream(&msgs, &opts)
            .await
            .unwrap()
            .collect()
            .await;
        assert_eq!(chunks.len(), 1);
        assert_eq!(
            chunks[0].as_ref().unwrap(),
            &LlmChunk {
                delta: "hi".to_string(),
                usage: None,
            }
        );
    }

    // --- Embedders ---------------------------------------------------------

    struct StubEmbedder {
        dim: usize,
        id: &'static str,
    }

    #[async_trait]
    impl Embedder for StubEmbedder {
        async fn embed(&self, texts: &[&str]) -> HirnResult<Vec<Embedding>> {
            Ok(texts
                .iter()
                .map(|t| Embedding {
                    vector: vec![t.len() as f32; self.dim],
                    model_id: self.id.to_string(),
                })
                .collect())
        }
        fn dimensions(&self) -> usize {
            self.dim
        }
        fn model_id(&self) -> &str {
            self.id
        }
        fn max_input_tokens(&self) -> usize {
            8192
        }
    }

    #[tokio::test]
    async fn arc_embedder_forwards_to_inner() {
        let shared: Arc<dyn Embedder> = Arc::new(StubEmbedder {
            dim: 2,
            id: "arc-stub",
        });
        assert_eq!(shared.dimensions(), 2);
        assert_eq!(shared.model_id(), "arc-stub");
        assert_eq!(shared.max_input_tokens(), 8192);
        let out = shared.embed(&["abc"]).await.unwrap();
        assert_eq!(out[0].vector, vec![3.0, 3.0]);
        // Defaults are forwarded too: multivector is unsupported and errors.
        assert!(!shared.supports_multivec());
        assert!(shared.embed_multivec(&["abc"]).await.is_err());
    }

    #[tokio::test]
    async fn embedder_adapter_delegates_embed_source() {
        let adapter = EmbedderAdapter::new(StubEmbedder {
            dim: 4,
            id: "stub-v1",
        });
        let result = adapter.embed_source(&["hello", "world"]).await.unwrap();
        assert_eq!(result.len(), 2);
        assert_eq!(result[0].vector.len(), 4);
        assert_eq!(result[0].vector, vec![5.0; 4]);
        assert_eq!(result[1].vector, vec![5.0; 4]);
    }

    // Purely synchronous: no runtime needed.
    #[test]
    fn embedder_adapter_name_and_dims() {
        let adapter = EmbedderAdapter::new(StubEmbedder {
            dim: 128,
            id: "my-model",
        });
        assert_eq!(adapter.name(), "my-model");
        assert_eq!(adapter.dims(), 128);
    }

    #[tokio::test]
    async fn default_embed_query_delegates_to_embed_source() {
        let adapter = EmbedderAdapter::new(StubEmbedder { dim: 3, id: "sym" });
        let source = adapter.embed_source(&["test"]).await.unwrap();
        let query = adapter.embed_query(&["test"]).await.unwrap();
        assert_eq!(source, query);
    }

    struct AsymStub;

    #[async_trait]
    impl AsymmetricEmbedder for AsymStub {
        async fn embed_source(&self, texts: &[&str]) -> HirnResult<Vec<Embedding>> {
            Ok(texts
                .iter()
                .map(|_| Embedding {
                    vector: vec![1.0, 0.0, 0.0],
                    model_id: "asym".to_string(),
                })
                .collect())
        }
        async fn embed_query(&self, texts: &[&str]) -> HirnResult<Vec<Embedding>> {
            Ok(texts
                .iter()
                .map(|_| Embedding {
                    vector: vec![0.0, 1.0, 0.0],
                    model_id: "asym".to_string(),
                })
                .collect())
        }
        fn name(&self) -> &str {
            "asym"
        }
        fn dims(&self) -> usize {
            3
        }
    }

    #[tokio::test]
    async fn asymmetric_embedder_returns_different_vectors() {
        let e = AsymStub;
        let source = e.embed_source(&["hello"]).await.unwrap();
        let query = e.embed_query(&["hello"]).await.unwrap();
        assert_ne!(source[0].vector, query[0].vector);
        assert_eq!(source[0].vector, vec![1.0, 0.0, 0.0]);
        assert_eq!(query[0].vector, vec![0.0, 1.0, 0.0]);
    }
}