Skip to main content

erio_embedding/
task.rs

//! Task type definitions and prompt formatting for embedding models.
2
/// The type of embedding task, used to format input prompts.
///
/// Every variant maps to a fixed phrase, available through
/// [`TaskType::description`], which [`format_query`] inserts into the prompt.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum TaskType {
    /// Search result retrieval. This is the [`Default`] variant.
    #[default]
    SearchResult,
    /// Search query.
    SearchQuery,
    /// Text classification.
    Classification,
    /// Text clustering.
    Clustering,
    /// Semantic similarity.
    SemanticSimilarity,
    /// Fact verification.
    FactVerification,
    /// Code retrieval.
    CodeRetrieval,
}

impl TaskType {
    /// Returns the human-readable description used in prompt formatting.
    ///
    /// The returned phrase is a lowercase, space-separated `'static` string.
    /// Because this is a `const fn`, it can also be evaluated in constant
    /// contexts, e.g. `const P: &str = TaskType::Clustering.description();`.
    pub const fn description(self) -> &'static str {
        // Exhaustive on purpose: adding a variant must force a phrase to be
        // chosen here rather than silently falling into a catch-all arm.
        match self {
            Self::SearchResult => "search result",
            Self::SearchQuery => "search query",
            Self::Classification => "classification",
            Self::Clustering => "clustering",
            Self::SemanticSimilarity => "semantic similarity",
            Self::FactVerification => "fact verification",
            Self::CodeRetrieval => "code retrieval",
        }
    }
}
37
38/// Formats a query text with a task-type prefix for embedding.
39pub fn format_query(text: &str, task_type: TaskType) -> String {
40    format!("task: {} | query: {text}", task_type.description())
41}
42
/// Formats a document text with an optional title for embedding.
///
/// The result has the shape `title: <title> | text: <text>`.
///
/// When no usable title is available — `None`, an empty string, or a string
/// consisting only of whitespace — the literal placeholder `none` is used, so
/// the prompt never contains a blank title slot. A non-blank title is inserted
/// verbatim (it is not trimmed), as is `text`.
pub fn format_document(text: &str, title: Option<&str>) -> String {
    // A blank title carries no information; fold it into the "missing" case
    // instead of emitting `title:  | text: ...`.
    let title = title
        .filter(|candidate| !candidate.trim().is_empty())
        .unwrap_or("none");
    format!("title: {title} | text: {text}")
}
48
#[cfg(test)]
mod tests {
    //! Unit tests pinning the exact prompt strings produced by this module.

    use super::*;

    // === TaskType::description() tests ===

    #[test]
    fn search_result_description() {
        let description = TaskType::SearchResult.description();
        assert_eq!(description, "search result");
    }

    #[test]
    fn search_query_description() {
        let description = TaskType::SearchQuery.description();
        assert_eq!(description, "search query");
    }

    #[test]
    fn classification_description() {
        let description = TaskType::Classification.description();
        assert_eq!(description, "classification");
    }

    #[test]
    fn clustering_description() {
        let description = TaskType::Clustering.description();
        assert_eq!(description, "clustering");
    }

    #[test]
    fn semantic_similarity_description() {
        let description = TaskType::SemanticSimilarity.description();
        assert_eq!(description, "semantic similarity");
    }

    #[test]
    fn fact_verification_description() {
        let description = TaskType::FactVerification.description();
        assert_eq!(description, "fact verification");
    }

    #[test]
    fn code_retrieval_description() {
        let description = TaskType::CodeRetrieval.description();
        assert_eq!(description, "code retrieval");
    }

    // === Default ===

    #[test]
    fn default_is_search_result() {
        let fallback = TaskType::default();
        assert_eq!(fallback, TaskType::SearchResult);
    }

    // === format_query() tests ===

    #[test]
    fn format_query_with_search_result() {
        assert_eq!(
            format_query("what is rust", TaskType::SearchResult),
            "task: search result | query: what is rust"
        );
    }

    #[test]
    fn format_query_with_classification() {
        assert_eq!(
            format_query("hello world", TaskType::Classification),
            "task: classification | query: hello world"
        );
    }

    // === format_document() tests ===

    #[test]
    fn format_document_without_title() {
        assert_eq!(
            format_document("some document text", None),
            "title: none | text: some document text"
        );
    }

    #[test]
    fn format_document_with_title() {
        assert_eq!(
            format_document("some document text", Some("My Title")),
            "title: My Title | text: some document text"
        );
    }
}