// markovify_rs/utils.rs

1//! Utility functions for combining models
2
3use crate::chain::Chain;
4use crate::errors::{MarkovError, Result};
5use crate::text::Text;
6use fxhash::FxHashMap;
7
8/// Combine multiple models into one
9///
10/// # Arguments
11/// * `models` - A list of models to combine (Chain or Text)
12/// * `weights` - Optional weights for each model (default: equal weights)
13///
14/// # Returns
15/// A combined model of the same type as the input models
16pub fn combine_chains(models: Vec<&Chain>, weights: Option<Vec<f64>>) -> Result<Chain> {
17    if models.is_empty() {
18        return Err(MarkovError::CombineError("No models provided".to_string()));
19    }
20
21    let weights = weights.unwrap_or_else(|| vec![1.0; models.len()]);
22
23    if models.len() != weights.len() {
24        return Err(MarkovError::CombineError(
25            "Models and weights lengths must be equal".to_string(),
26        ));
27    }
28
29    // Check that all models have the same state size
30    let state_size = models[0].state_size();
31    for model in &models[1..] {
32        if model.state_size() != state_size {
33            return Err(MarkovError::CombineError(
34                "All models must have the same state size".to_string(),
35            ));
36        }
37    }
38
39    // Check that no model is compiled
40    for model in &models {
41        if model.is_compiled() {
42            return Err(MarkovError::CombineError(
43                "Cannot combine compiled models".to_string(),
44            ));
45        }
46    }
47
48    // Combine the models by merging their model HashMaps
49    let mut combined: FxHashMap<Vec<String>, FxHashMap<String, usize>> = FxHashMap::default();
50
51    for (model, &weight) in models.iter().zip(weights.iter()) {
52        for (state, options) in model.model() {
53            let entry = combined.entry(state.clone()).or_default();
54            for (next_word, &count) in options {
55                let prev_count = entry.get(next_word).unwrap_or(&0);
56                let new_count = *prev_count + (count as f64 * weight).round() as usize;
57                entry.insert(next_word.clone(), new_count);
58            }
59        }
60    }
61
62    // Create a new chain directly with the combined model
63    Ok(Chain::from_combined_model(combined, state_size))
64}
65
66/// Combine multiple Text models
67pub fn combine_texts(models: Vec<&Text>, weights: Option<Vec<f64>>) -> Result<Text> {
68    if models.is_empty() {
69        return Err(MarkovError::CombineError("No models provided".to_string()));
70    }
71
72    let weights = weights.unwrap_or_else(|| vec![1.0; models.len()]);
73
74    if models.len() != weights.len() {
75        return Err(MarkovError::CombineError(
76            "Models and weights lengths must be equal".to_string(),
77        ));
78    }
79
80    // Check that all models have the same state size
81    let state_size = models[0].state_size();
82    for model in &models[1..] {
83        if model.state_size() != state_size {
84            return Err(MarkovError::CombineError(
85                "All models must have the same state size".to_string(),
86            ));
87        }
88    }
89
90    // Check that no model is compiled
91    for model in &models {
92        if model.chain().is_compiled() {
93            return Err(MarkovError::CombineError(
94                "Cannot combine compiled models".to_string(),
95            ));
96        }
97    }
98
99    // Combine the underlying chains
100    let chains: Vec<&Chain> = models.iter().map(|m| m.chain()).collect();
101    let combined_chain = combine_chains(chains, Some(weights.clone()))?;
102
103    // Combine parsed sentences if any model retains original
104    let combined_sentences = if models.iter().any(|m| m.retain_original()) {
105        let mut all_sentences = Vec::new();
106        for model in models {
107            if model.retain_original() {
108                if let Some(sentences) = model.parsed_sentences() {
109                    all_sentences.extend(sentences.iter().cloned());
110                }
111            }
112        }
113        Some(all_sentences)
114    } else {
115        None
116    };
117
118    Ok(Text::from_chain(
119        combined_chain,
120        combined_sentences.clone(),
121        combined_sentences.is_some(),
122    ))
123}
124
125/// Helper enum for combining different model types
126pub enum ModelRef<'a> {
127    Chain(&'a Chain),
128    Text(&'a Text),
129}
130
131impl<'a> ModelRef<'a> {
132    fn chain(&self) -> &Chain {
133        match self {
134            ModelRef::Chain(c) => c,
135            ModelRef::Text(t) => t.chain(),
136        }
137    }
138}
139
140/// Combine models of potentially different types
141pub fn combine_models(models: Vec<ModelRef>, weights: Option<Vec<f64>>) -> Result<CombinedResult> {
142    if models.is_empty() {
143        return Err(MarkovError::CombineError("No models provided".to_string()));
144    }
145
146    let weights = weights.unwrap_or_else(|| vec![1.0; models.len()]);
147
148    if models.len() != weights.len() {
149        return Err(MarkovError::CombineError(
150            "Models and weights lengths must be equal".to_string(),
151        ));
152    }
153
154    // Check state sizes match
155    let state_size = models[0].chain().state_size();
156
157    for model in &models[1..] {
158        let model_state_size = model.chain().state_size();
159        if model_state_size != state_size {
160            return Err(MarkovError::CombineError(
161                "All models must have the same state size".to_string(),
162            ));
163        }
164    }
165
166    // Check types match
167    let first_is_chain = matches!(models[0], ModelRef::Chain(_));
168    for model in &models[1..] {
169        let is_chain = matches!(model, ModelRef::Chain(_));
170        if is_chain != first_is_chain {
171            return Err(MarkovError::CombineError(
172                "All models must be of the same type".to_string(),
173            ));
174        }
175    }
176
177    // Check no model is compiled
178    for model in &models {
179        let is_compiled = model.chain().is_compiled();
180        if is_compiled {
181            return Err(MarkovError::CombineError(
182                "Cannot combine compiled models".to_string(),
183            ));
184        }
185    }
186
187    // Combine based on type
188    if first_is_chain {
189        let chains: Vec<&Chain> = models
190            .iter()
191            .filter_map(|m| match m {
192                ModelRef::Chain(c) => Some(*c),
193                _ => None,
194            })
195            .collect();
196        Ok(CombinedResult::Chain(combine_chains(
197            chains,
198            Some(weights),
199        )?))
200    } else {
201        let texts: Vec<&Text> = models
202            .iter()
203            .filter_map(|m| match m {
204                ModelRef::Text(t) => Some(*t),
205                _ => None,
206            })
207            .collect();
208        Ok(CombinedResult::Text(combine_texts(texts, Some(weights))?))
209    }
210}
211
212/// Result of combining models
213pub enum CombinedResult {
214    Chain(Chain),
215    Text(Text),
216}
217
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a one-sentence corpus from a list of words.
    fn corpus(words: &[&str]) -> Vec<Vec<String>> {
        vec![words.iter().copied().map(String::from).collect()]
    }

    #[test]
    fn test_combine_chains_equal_weights() {
        let chain1 = Chain::new(&corpus(&["hello", "world"]), 1);
        let chain2 = Chain::new(&corpus(&["hello", "rust"]), 1);

        let combined = combine_chains(vec![&chain1, &chain2], None).unwrap();
        assert_eq!(combined.state_size(), 1);
    }

    #[test]
    fn test_combine_texts() {
        let text1 = "Hello world. Goodbye world.";
        let text2 = "Hello rust. Goodbye rust.";
        let model1 = Text::new(text1, 2, true, true, None).unwrap();
        let model2 = Text::new(text2, 2, true, true, None).unwrap();

        let combined = combine_texts(vec![&model1, &model2], None).unwrap();
        assert_eq!(combined.state_size(), 2);
    }

    #[test]
    fn test_combine_mismatched_state_sizes() {
        let chain1 = Chain::new(&corpus(&["hello"]), 1);
        let chain2 = Chain::new(&corpus(&["hello", "world"]), 2);

        let result = combine_chains(vec![&chain1, &chain2], None);
        assert!(result.is_err());
    }

    // An empty model list must be rejected rather than indexed into.
    #[test]
    fn test_combine_chains_no_models() {
        let result = combine_chains(Vec::new(), None);
        assert!(matches!(result, Err(MarkovError::CombineError(_))));
    }

    // One model with two weights: the lengths disagree.
    #[test]
    fn test_combine_chains_weight_count_mismatch() {
        let chain = Chain::new(&corpus(&["hello", "world"]), 1);

        let result = combine_chains(vec![&chain], Some(vec![1.0, 2.0]));
        assert!(matches!(result, Err(MarkovError::CombineError(_))));
    }

    // Homogeneous chain handles come back as the `Chain` variant.
    #[test]
    fn test_combine_models_chains() {
        let chain1 = Chain::new(&corpus(&["hello", "world"]), 1);
        let chain2 = Chain::new(&corpus(&["hello", "rust"]), 1);

        let result = combine_models(
            vec![ModelRef::Chain(&chain1), ModelRef::Chain(&chain2)],
            None,
        );
        assert!(matches!(result, Ok(CombinedResult::Chain(_))));
    }

    // Mixing a chain handle with a text handle is an error.
    #[test]
    fn test_combine_models_mixed_types() {
        let chain = Chain::new(&corpus(&["hello", "world"]), 2);
        let text = Text::new("Hello world. Goodbye world.", 2, true, true, None).unwrap();

        let result = combine_models(vec![ModelRef::Chain(&chain), ModelRef::Text(&text)], None);
        assert!(matches!(result, Err(MarkovError::CombineError(_))));
    }
}