1use std::sync::Arc;
7
8use async_trait::async_trait;
9
10use crate::error::HirnResult;
11
12#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)]
14pub struct Embedding {
15 pub vector: Vec<f32>,
17 pub model_id: String,
19}
20
/// An embedding made of several vectors rather than a single one.
///
/// Returned by [`Embedder::embed_multivec`].
#[derive(Clone, Debug, PartialEq)]
pub struct MultivectorEmbedding {
    /// The individual vectors that make up this embedding.
    pub vectors: Vec<Vec<f32>>,
    /// Identifier of the model associated with these vectors.
    pub model_id: String,
}
29
/// One entry of a reranking result: a document index and its score.
#[derive(Clone, Debug, PartialEq)]
pub struct RerankResult {
    /// Index of the document this score belongs to.
    pub index: usize,
    /// Score assigned to the document.
    pub score: f32,
}
38
39#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)]
41pub struct ExtractedEntity {
42 pub name: String,
43 pub entity_type: String,
44 pub confidence: f32,
45}
46
47#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)]
49pub struct ExtractedRelation {
50 pub source: String,
51 pub target: String,
52 pub relation_type: String,
53 pub weight: f32,
54}
55
56#[async_trait]
64pub trait Embedder: Send + Sync {
65 async fn embed(&self, texts: &[&str]) -> HirnResult<Vec<Embedding>>;
71
72 fn dimensions(&self) -> usize;
74
75 fn model_id(&self) -> &str;
77
78 fn max_input_tokens(&self) -> usize;
80
81 async fn embed_multivec(&self, _texts: &[&str]) -> HirnResult<Vec<MultivectorEmbedding>> {
87 Err(crate::error::HirnError::InvalidInput(
88 "this embedder does not support multivector embeddings".into(),
89 ))
90 }
91
92 fn supports_multivec(&self) -> bool {
94 false
95 }
96}
97
98#[async_trait]
99impl<T: Embedder + ?Sized> Embedder for Arc<T> {
100 async fn embed(&self, texts: &[&str]) -> HirnResult<Vec<Embedding>> {
101 self.as_ref().embed(texts).await
102 }
103
104 fn dimensions(&self) -> usize {
105 self.as_ref().dimensions()
106 }
107
108 fn model_id(&self) -> &str {
109 self.as_ref().model_id()
110 }
111
112 fn max_input_tokens(&self) -> usize {
113 self.as_ref().max_input_tokens()
114 }
115
116 async fn embed_multivec(&self, texts: &[&str]) -> HirnResult<Vec<MultivectorEmbedding>> {
117 self.as_ref().embed_multivec(texts).await
118 }
119
120 fn supports_multivec(&self) -> bool {
121 self.as_ref().supports_multivec()
122 }
123}
124
/// Counts tokens in text.
pub trait TokenCounter: Send + Sync {
    /// Returns the number of tokens in `text`.
    fn count_tokens(&self, text: &str) -> usize;

    /// Counts tokens for every entry of `texts`, preserving input order.
    ///
    /// The default implementation calls [`count_tokens`](TokenCounter::count_tokens)
    /// once per entry.
    fn count_tokens_batch(&self, texts: &[&str]) -> Vec<usize> {
        let mut counts = Vec::with_capacity(texts.len());
        for text in texts {
            counts.push(self.count_tokens(text));
        }
        counts
    }
}
140
141#[derive(Debug, Clone, Copy)]
143pub struct CharEstimateCounter;
144
145impl TokenCounter for CharEstimateCounter {
146 fn count_tokens(&self, text: &str) -> usize {
147 text.len().div_ceil(4)
148 }
149}
150
151#[async_trait]
155pub trait Reranker: Send + Sync {
156 async fn rerank(
159 &self,
160 query: &str,
161 documents: &[&str],
162 top_k: usize,
163 ) -> HirnResult<Vec<RerankResult>>;
164}
165
166#[derive(Debug, Clone, Copy)]
168pub struct NoopReranker;
169
170#[async_trait]
171impl Reranker for NoopReranker {
172 async fn rerank(
173 &self,
174 _query: &str,
175 documents: &[&str],
176 top_k: usize,
177 ) -> HirnResult<Vec<RerankResult>> {
178 Ok(documents
179 .iter()
180 .enumerate()
181 .take(top_k)
182 .map(|(i, _)| RerankResult {
183 index: i,
184 score: 1.0 - (i as f32 / documents.len().max(1) as f32),
185 })
186 .collect())
187 }
188}
189
190#[async_trait]
194pub trait EntityExtractor: Send + Sync {
195 async fn extract_entities(
197 &self,
198 text: &str,
199 entity_types: &[&str],
200 ) -> HirnResult<Vec<ExtractedEntity>>;
201
202 async fn extract_relations(
204 &self,
205 text: &str,
206 entities: &[ExtractedEntity],
207 ) -> HirnResult<Vec<ExtractedRelation>>;
208}
209
210#[derive(Debug, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
214pub struct ChatMessage {
215 pub role: String,
216 pub content: String,
217}
218
/// Output format requested from an LLM. Defaults to [`ResponseFormat::Text`].
#[derive(Clone, Debug, Default, Eq, PartialEq)]
pub enum ResponseFormat {
    /// Plain text output.
    #[default]
    Text,
    /// JSON object output.
    JsonObject,
    /// JSON output described by the schema carried as a string.
    JsonSchema(String),
}
230
/// Token counts for a single LLM call.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub struct TokenUsage {
    /// Number of prompt tokens.
    pub prompt_tokens: u32,
    /// Number of completion tokens.
    pub completion_tokens: u32,
}

impl TokenUsage {
    /// Returns `prompt_tokens + completion_tokens`.
    ///
    /// The sum saturates at `u32::MAX`. A plain `+` would panic in debug
    /// builds and silently wrap in release builds on overflow.
    pub const fn total(&self) -> u32 {
        self.prompt_tokens.saturating_add(self.completion_tokens)
    }
}
244
245#[derive(Debug, Clone, PartialEq, Eq)]
247pub struct LlmResponse {
248 pub content: String,
249 pub usage: Option<TokenUsage>,
250}
251
252#[derive(Debug, Clone, PartialEq, Eq)]
254pub struct LlmChunk {
255 pub delta: String,
257 pub usage: Option<TokenUsage>,
259}
260
261#[derive(Debug, Clone, PartialEq)]
263pub struct LlmOptions {
264 pub model_override: Option<String>,
265 pub temperature: f32,
266 pub max_tokens: u32,
267 pub response_format: ResponseFormat,
268}
269
270impl Default for LlmOptions {
271 fn default() -> Self {
272 Self {
273 model_override: None,
274 temperature: 0.0,
275 max_tokens: 1024,
276 response_format: ResponseFormat::Text,
277 }
278 }
279}
280
281pub type LlmStream = std::pin::Pin<Box<dyn futures::Stream<Item = HirnResult<LlmChunk>> + Send>>;
283
284#[async_trait]
286pub trait LlmProvider: Send + Sync {
287 async fn generate_text(
289 &self,
290 messages: &[ChatMessage],
291 options: &LlmOptions,
292 ) -> HirnResult<String>;
293
294 async fn generate(
298 &self,
299 messages: &[ChatMessage],
300 options: &LlmOptions,
301 ) -> HirnResult<LlmResponse> {
302 let content = self.generate_text(messages, options).await?;
303 Ok(LlmResponse {
304 content,
305 usage: None,
306 })
307 }
308
309 async fn generate_stream(
313 &self,
314 messages: &[ChatMessage],
315 options: &LlmOptions,
316 ) -> HirnResult<LlmStream> {
317 let text = self.generate_text(messages, options).await?;
318 let chunk = LlmChunk {
319 delta: text,
320 usage: None,
321 };
322 Ok(Box::pin(futures::stream::once(async { Ok(chunk) })))
323 }
324
325 fn model_id(&self) -> &str;
327}
328
329#[async_trait]
339pub trait AsymmetricEmbedder: Send + Sync {
340 async fn embed_source(&self, texts: &[&str]) -> HirnResult<Vec<Embedding>>;
342
343 async fn embed_query(&self, texts: &[&str]) -> HirnResult<Vec<Embedding>> {
347 self.embed_source(texts).await
348 }
349
350 fn name(&self) -> &str;
352
353 fn dims(&self) -> usize;
355}
356
357pub struct EmbedderAdapter<E: Embedder> {
362 inner: E,
363}
364
365impl<E: Embedder> EmbedderAdapter<E> {
366 pub fn new(inner: E) -> Self {
368 Self { inner }
369 }
370}
371
372#[async_trait]
373impl<E: Embedder> AsymmetricEmbedder for EmbedderAdapter<E> {
374 async fn embed_source(&self, texts: &[&str]) -> HirnResult<Vec<Embedding>> {
375 self.inner.embed(texts).await
376 }
377
378 fn name(&self) -> &str {
379 self.inner.model_id()
380 }
381
382 fn dims(&self) -> usize {
383 self.inner.dimensions()
384 }
385}
386
/// Keeps the first `target_dims` components of `embedding` and rescales
/// them to unit L2 norm.
///
/// Returns `None` when `embedding` has fewer than `target_dims` components.
/// If the norm of the kept prefix is not strictly positive (for example an
/// all-zero prefix), the prefix is returned unchanged.
#[must_use]
pub fn truncate_matryoshka(embedding: &[f32], target_dims: usize) -> Option<Vec<f32>> {
    // `get` yields `None` exactly when the input is shorter than requested.
    let head = embedding.get(..target_dims)?;
    let norm = head.iter().map(|v| v * v).sum::<f32>().sqrt();

    let rescaled: Vec<f32> = if norm > 0.0 {
        head.iter().map(|v| v / norm).collect()
    } else {
        head.to_vec()
    };
    Some(rescaled)
}
404
#[cfg(test)]
mod tests {
    use super::*;

    /// The estimate is the byte length divided by four, rounded up.
    #[test]
    fn char_estimate_counter() {
        let counter = CharEstimateCounter;
        assert_eq!(counter.count_tokens(""), 0);
        assert_eq!(counter.count_tokens("hi"), 1);
        // 11 bytes round up to three groups of four.
        assert_eq!(counter.count_tokens("hello world"), 3);
    }

    /// The batch default applies `count_tokens` to each entry in order.
    #[test]
    fn char_estimate_batch() {
        let counter = CharEstimateCounter;
        let counts = counter.count_tokens_batch(&["a", "abcdefgh"]);
        assert_eq!(counts, vec![1, 2]);
    }

    /// `NoopReranker` keeps input order, honours `top_k`, and assigns
    /// strictly decreasing scores.
    #[tokio::test]
    async fn noop_reranker_returns_descending() {
        let docs = ["alpha", "beta", "gamma"];
        let results = NoopReranker.rerank("q", &docs, 2).await.unwrap();
        assert_eq!(results.len(), 2);
        assert_eq!(results[0].index, 0);
        assert_eq!(results[1].index, 1);
        assert!(results[0].score > results[1].score);
    }

    /// Truncation keeps the requested prefix and rescales it to unit norm.
    #[test]
    fn matryoshka_truncate() {
        let emb = vec![3.0, 4.0, 0.0, 0.0];
        let truncated = truncate_matryoshka(&emb, 2).unwrap();
        assert_eq!(truncated.len(), 2);
        let norm: f32 = truncated.iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!((norm - 1.0).abs() < 1e-5);
    }

    /// Asking for more dimensions than are available yields `None`.
    #[test]
    fn matryoshka_too_short() {
        assert!(truncate_matryoshka(&[1.0, 2.0], 5).is_none());
    }

    /// Deterministic embedder: every component of a text's vector equals
    /// the byte length of that text.
    struct StubEmbedder {
        dim: usize,
        id: &'static str,
    }

    #[async_trait]
    impl Embedder for StubEmbedder {
        async fn embed(&self, texts: &[&str]) -> HirnResult<Vec<Embedding>> {
            Ok(texts
                .iter()
                .map(|t| Embedding {
                    vector: vec![t.len() as f32; self.dim],
                    model_id: self.id.to_string(),
                })
                .collect())
        }

        fn dimensions(&self) -> usize {
            self.dim
        }

        fn model_id(&self) -> &str {
            self.id
        }

        fn max_input_tokens(&self) -> usize {
            8192
        }
    }

    #[tokio::test]
    async fn embedder_adapter_delegates_embed_source() {
        let adapter = EmbedderAdapter::new(StubEmbedder {
            dim: 4,
            id: "stub-v1",
        });
        let result = adapter.embed_source(&["hello", "world"]).await.unwrap();
        assert_eq!(result.len(), 2);
        assert_eq!(result[0].vector.len(), 4);
        // Both inputs are five bytes long, so both vectors are filled with 5.0.
        assert_eq!(result[0].vector, vec![5.0; 4]);
        assert_eq!(result[1].vector, vec![5.0; 4]);
    }

    /// `name` and `dims` are synchronous, so no async runtime is needed.
    #[test]
    fn embedder_adapter_name_and_dims() {
        let adapter = EmbedderAdapter::new(StubEmbedder {
            dim: 128,
            id: "my-model",
        });
        assert_eq!(adapter.name(), "my-model");
        assert_eq!(adapter.dims(), 128);
    }

    /// Without an override, `embed_query` matches `embed_source`.
    #[tokio::test]
    async fn default_embed_query_delegates_to_embed_source() {
        let adapter = EmbedderAdapter::new(StubEmbedder { dim: 3, id: "sym" });
        let source = adapter.embed_source(&["test"]).await.unwrap();
        let query = adapter.embed_query(&["test"]).await.unwrap();
        assert_eq!(source, query);
    }

    /// Embedder that overrides `embed_query` to return a different vector
    /// than `embed_source`.
    struct AsymStub;

    #[async_trait]
    impl AsymmetricEmbedder for AsymStub {
        async fn embed_source(&self, texts: &[&str]) -> HirnResult<Vec<Embedding>> {
            Ok(texts
                .iter()
                .map(|_| Embedding {
                    vector: vec![1.0, 0.0, 0.0],
                    model_id: "asym".to_string(),
                })
                .collect())
        }

        async fn embed_query(&self, texts: &[&str]) -> HirnResult<Vec<Embedding>> {
            Ok(texts
                .iter()
                .map(|_| Embedding {
                    vector: vec![0.0, 1.0, 0.0],
                    model_id: "asym".to_string(),
                })
                .collect())
        }

        fn name(&self) -> &str {
            "asym"
        }

        fn dims(&self) -> usize {
            3
        }
    }

    #[tokio::test]
    async fn asymmetric_embedder_returns_different_vectors() {
        let embedder = AsymStub;
        let source = embedder.embed_source(&["hello"]).await.unwrap();
        let query = embedder.embed_query(&["hello"]).await.unwrap();
        assert_ne!(source[0].vector, query[0].vector);
        assert_eq!(source[0].vector, vec![1.0, 0.0, 0.0]);
        assert_eq!(query[0].vector, vec![0.0, 1.0, 0.0]);
    }
}