orbok_workers/
embedding.rs1use orbok_cache::{CacheService, EngineOptions, OrbokCacheNamespace};
6use orbok_core::{FileId, ModelId, OrbokError, OrbokResult};
7use orbok_db::Catalog;
8use orbok_db::repo::{
9 ChunkRepository, EmbeddingRepository, FileRepository, NewEmbedding, SourceRepository,
10};
11use orbok_extract::ExtractOutput;
12use orbok_fs::{GuardedSource, PathGuard};
13use orbok_models::{EmbeddingModel, MockEmbeddingModel};
14use std::path::Path;
15
16pub struct EmbeddingWorker<'a> {
18 catalog: &'a Catalog,
19 cache: &'a CacheService,
20 model: Box<dyn EmbeddingModel>,
21 model_id: ModelId,
22}
23
24impl<'a> EmbeddingWorker<'a> {
25 pub fn with_mock(catalog: &'a Catalog, cache: &'a CacheService) -> Self {
27 Self {
28 catalog,
29 cache,
30 model: Box::new(MockEmbeddingModel),
31 model_id: ModelId::from_string("mock_mock-v1".to_string()),
32 }
33 }
34
35
36 pub fn with_model(
40 catalog: &'a Catalog,
41 cache: &'a CacheService,
42 model: Box<dyn EmbeddingModel>,
43 model_id: ModelId,
44 ) -> Self {
45 Self { catalog, cache, model, model_id }
46 }
47
48 pub fn run(&self, file_id: &FileId) -> OrbokResult<()> {
50 let files = FileRepository::new(self.catalog);
51 let record = files
52 .get_by_id(file_id)?
53 .ok_or(OrbokError::FileNotFound)?;
54 let sources = SourceRepository::new(self.catalog);
55 let source = sources
56 .get(&record.source_id)?
57 .ok_or(OrbokError::SourceNotFound)?;
58
59 let guard = PathGuard::new(vec![GuardedSource::from_record(&source)]);
62 let validated = guard.validate(Path::new(&record.canonical_path))?;
63 let engine = self.cache.engine::<ExtractOutput>(
64 self.catalog,
65 &OrbokCacheNamespace::ExtractSegments,
66 EngineOptions::default(),
67 )?;
68 let Some(extract_output) = CacheService::get_fresh(&engine, &validated)? else {
69 return Ok(()); };
71
72 let chunks = ChunkRepository::new(self.catalog).list_for_file(file_id)?;
74 if chunks.is_empty() {
75 return Ok(());
76 }
77
78 let all_text: String = extract_output
82 .segments
83 .iter()
84 .map(|s| s.text.as_str())
85 .collect::<Vec<_>>()
86 .join("\n");
87 let texts: Vec<String> = chunks
88 .iter()
89 .map(|chunk| {
90 if let Some(heading) = &chunk.heading_path {
91 format!("{heading}\n{all_text}")
92 } else {
93 all_text.clone()
94 }
95 })
96 .collect();
97
98 let text_refs: Vec<&str> = texts.iter().map(|s| s.as_str()).collect();
99 let vectors = self.model.embed_batch(&text_refs)?;
100
101 let embeddings = EmbeddingRepository::new(self.catalog);
102 for (chunk, vector) in chunks.iter().zip(vectors.into_iter()) {
103 embeddings.upsert(&NewEmbedding {
104 chunk_id: chunk.chunk_id.clone(),
105 model_id: self.model_id.clone(),
106 dimension: self.model.dimension(),
107 vector,
108 })?;
109 }
110 Ok(())
111 }
112
113 pub fn model_id(&self) -> &ModelId {
114 &self.model_id
115 }
116}