//! Text-embedding inference crate: loads an ONNX embedding model and its
//! tokenizer, turns text into embedding vectors, and matches them against
//! stored documents by similarity.

use std::sync::Arc;

pub use document::Document;
pub use document::Metadata;
pub use embedding::Embedding;
pub use embedding::Semantic;
pub use embedding::semantic::SemanticError;
pub use similarity::CosineSimilarity;
pub use similarity::DocumentMatch;
pub use similarity::EmbeddingMatch;
pub use similarity::RelevanceScore;
pub use similarity::Similarity;
pub use store::EmbeddingStore;
pub use store::InMemoryEmbeddingStore;

pub mod document;
pub mod embedding;
pub mod similarity;
pub mod store;

/// Returns the cosine-similarity metric behind the [`Similarity`] trait object,
/// so callers are not tied to the concrete [`CosineSimilarity`] type.
pub fn get_cosine_similarity() -> Arc<dyn Similarity> {
    Arc::new(CosineSimilarity {})
}

/// Builds a [`Semantic`] embedder from in-memory model and tokenizer bytes.
pub fn init_semantic(model: Vec<u8>, tokenizer_data: Vec<u8>) -> Result<Arc<Semantic>, SemanticError> {
    let semantic = Semantic::init_semantic(model, tokenizer_data)?;
    Ok(Arc::new(semantic))
}
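
// A minimal usage sketch (the `include_bytes!` paths are illustrative, not fixed
// by this crate): embed the model and tokenizer in the binary and pass the raw
// bytes straight to `init_semantic`.
//
//     let model = include_bytes!("../model/model.onnx").to_vec();
//     let tokenizer = include_bytes!("../model/tokenizer.json").to_vec();
//     let semantic = init_semantic(model, tokenizer)?;
//     let vector = semantic.embed("hello world")?;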

/// Builds a [`Semantic`] embedder by reading the model and tokenizer from disk.
///
/// Fails with [`SemanticError::InitModelReadError`] or
/// [`SemanticError::InitTokenizerReadError`] when the corresponding file cannot be read.
pub fn init_semantic_with_path(model_path: &str, tokenizer_path: &str) -> Result<Arc<Semantic>, SemanticError> {
    let model = std::fs::read(model_path).map_err(|_| SemanticError::InitModelReadError)?;
    let tokenizer_data = std::fs::read(tokenizer_path).map_err(|_| SemanticError::InitTokenizerReadError)?;

    let semantic = Semantic::init_semantic(model, tokenizer_data)?;
    Ok(Arc::new(semantic))
}
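
// End-to-end sketch (paths illustrative): initialize from disk, index a document,
// and retrieve the closest match, mirroring the `should_find_relevant` test below.
//
//     let semantic = init_semantic_with_path("model/model.onnx", "model/tokenizer.json")?;
//     let store = InMemoryEmbeddingStore::new();
//     let text = "hello world";
//     let embedding = semantic.embed(text)?;
//     store.add(text.to_string(), embedding.clone(), Document::from(text.to_string()));
//     let matches = store.find_relevant(embedding, 1, 0.0);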

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    #[cfg_attr(feature = "ci", ignore)]
    fn test_init_semantic() {
        let model = std::fs::read("../model/model.onnx").unwrap();
        let tokenizer_data = std::fs::read("../model/tokenizer.json").unwrap();

        let semantic = init_semantic(model, tokenizer_data).unwrap();
        let embedding = semantic.embed("hello world").unwrap();
        // The bundled model emits 128-dimensional vectors.
        assert_eq!(embedding.len(), 128);
    }

    #[test]
    #[cfg_attr(feature = "ci", ignore)]
    fn should_find_relevant() {
        let model = std::fs::read("../model/model.onnx").unwrap();
        let tokenizer_data = std::fs::read("../model/tokenizer.json").unwrap();
        let semantic = init_semantic(model, tokenizer_data).unwrap();

        // Embed a plain-text snippet and a code snippet.
        let hello = "hello world";
        let pure_text_hello = semantic.embed(hello).unwrap();
        let code_hello_text = "print('hello world')";
        let code_hello = semantic.embed(code_hello_text).unwrap();

        // Index both, then query with the plain-text vector; with a limit of 1
        // and a zero score threshold, exactly one match should come back.
        let embedding_store = InMemoryEmbeddingStore::new();
        embedding_store.add(hello.to_string(), pure_text_hello.clone(), Document::from(hello.to_string()));
        embedding_store.add(code_hello_text.to_string(), code_hello.clone(), Document::from(code_hello_text.to_string()));

        let matches = embedding_store.find_relevant(pure_text_hello, 1, 0.0);
        assert_eq!(matches.len(), 1);
    }
}

// Pull in the UniFFI scaffolding generated from the `inference.udl` interface definition.
uniffi::include_scaffolding!("inference");