1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
use std::{collections::HashMap, hash::Hash};

use sllm::message::PromptMessageGroup;

//
// PromptManager
#[derive(Debug)]
pub struct PromptManager<T: Hash + Eq> {
    prompts: HashMap<String, PromptMessageGroup>,
    patterns: HashMap<T, String>,
}

impl<T: Hash + Eq> PromptManager<T> {
    pub fn new() -> Self {
        Self {
            prompts: HashMap::new(),
            patterns: HashMap::new(),
        }
    }

    fn parse_pattern<'a>(pattern: &'a str) -> impl Iterator<Item = &'a str> {
        pattern.split_whitespace()
    }

    pub fn insert_prompt(&mut self, alias: &str, prompt: PromptMessageGroup) {
        self.prompts.insert(alias.into(), prompt);
    }

    pub fn register_pattern(&mut self, key: T, pattern: &str) {
        self.patterns.insert(key, pattern.into());
    }

    pub fn get(&self, key: T) -> Vec<PromptMessageGroup> {
        self.patterns
            .get(&key)
            .into_iter()
            .flat_map(|pattern| Self::parse_pattern(pattern))
            .filter_map(|alias| self.prompts.get(alias))
            .cloned()
            .collect()
    }
}