1use crate::chain::Chain;
4use crate::errors::{MarkovError, Result};
5use crate::text::Text;
6use fxhash::FxHashMap;
7
8pub fn combine_chains(models: Vec<&Chain>, weights: Option<Vec<f64>>) -> Result<Chain> {
17 if models.is_empty() {
18 return Err(MarkovError::CombineError("No models provided".to_string()));
19 }
20
21 let weights = weights.unwrap_or_else(|| vec![1.0; models.len()]);
22
23 if models.len() != weights.len() {
24 return Err(MarkovError::CombineError(
25 "Models and weights lengths must be equal".to_string(),
26 ));
27 }
28
29 let state_size = models[0].state_size();
31 for model in &models[1..] {
32 if model.state_size() != state_size {
33 return Err(MarkovError::CombineError(
34 "All models must have the same state size".to_string(),
35 ));
36 }
37 }
38
39 for model in &models {
41 if model.is_compiled() {
42 return Err(MarkovError::CombineError(
43 "Cannot combine compiled models".to_string(),
44 ));
45 }
46 }
47
48 let mut combined: FxHashMap<Vec<String>, FxHashMap<String, usize>> = FxHashMap::default();
50
51 for (model, &weight) in models.iter().zip(weights.iter()) {
52 for (state, options) in model.model() {
53 let entry = combined.entry(state.clone()).or_default();
54 for (next_word, &count) in options {
55 let prev_count = entry.get(next_word).unwrap_or(&0);
56 let new_count = *prev_count + (count as f64 * weight).round() as usize;
57 entry.insert(next_word.clone(), new_count);
58 }
59 }
60 }
61
62 Ok(Chain::from_combined_model(combined, state_size))
64}
65
66pub fn combine_texts(models: Vec<&Text>, weights: Option<Vec<f64>>) -> Result<Text> {
68 if models.is_empty() {
69 return Err(MarkovError::CombineError("No models provided".to_string()));
70 }
71
72 let weights = weights.unwrap_or_else(|| vec![1.0; models.len()]);
73
74 if models.len() != weights.len() {
75 return Err(MarkovError::CombineError(
76 "Models and weights lengths must be equal".to_string(),
77 ));
78 }
79
80 let state_size = models[0].state_size();
82 for model in &models[1..] {
83 if model.state_size() != state_size {
84 return Err(MarkovError::CombineError(
85 "All models must have the same state size".to_string(),
86 ));
87 }
88 }
89
90 for model in &models {
92 if model.chain().is_compiled() {
93 return Err(MarkovError::CombineError(
94 "Cannot combine compiled models".to_string(),
95 ));
96 }
97 }
98
99 let chains: Vec<&Chain> = models.iter().map(|m| m.chain()).collect();
101 let combined_chain = combine_chains(chains, Some(weights.clone()))?;
102
103 let combined_sentences = if models.iter().any(|m| m.retain_original()) {
105 let mut all_sentences = Vec::new();
106 for model in models {
107 if model.retain_original() {
108 if let Some(sentences) = model.parsed_sentences() {
109 all_sentences.extend(sentences.iter().cloned());
110 }
111 }
112 }
113 Some(all_sentences)
114 } else {
115 None
116 };
117
118 Ok(Text::from_chain(
119 combined_chain,
120 combined_sentences.clone(),
121 combined_sentences.is_some(),
122 ))
123}
124
125pub enum ModelRef<'a> {
127 Chain(&'a Chain),
128 Text(&'a Text),
129}
130
131impl<'a> ModelRef<'a> {
132 fn chain(&self) -> &Chain {
133 match self {
134 ModelRef::Chain(c) => c,
135 ModelRef::Text(t) => t.chain(),
136 }
137 }
138}
139
140pub fn combine_models(models: Vec<ModelRef>, weights: Option<Vec<f64>>) -> Result<CombinedResult> {
142 if models.is_empty() {
143 return Err(MarkovError::CombineError("No models provided".to_string()));
144 }
145
146 let weights = weights.unwrap_or_else(|| vec![1.0; models.len()]);
147
148 if models.len() != weights.len() {
149 return Err(MarkovError::CombineError(
150 "Models and weights lengths must be equal".to_string(),
151 ));
152 }
153
154 let state_size = models[0].chain().state_size();
156
157 for model in &models[1..] {
158 let model_state_size = model.chain().state_size();
159 if model_state_size != state_size {
160 return Err(MarkovError::CombineError(
161 "All models must have the same state size".to_string(),
162 ));
163 }
164 }
165
166 let first_is_chain = matches!(models[0], ModelRef::Chain(_));
168 for model in &models[1..] {
169 let is_chain = matches!(model, ModelRef::Chain(_));
170 if is_chain != first_is_chain {
171 return Err(MarkovError::CombineError(
172 "All models must be of the same type".to_string(),
173 ));
174 }
175 }
176
177 for model in &models {
179 let is_compiled = model.chain().is_compiled();
180 if is_compiled {
181 return Err(MarkovError::CombineError(
182 "Cannot combine compiled models".to_string(),
183 ));
184 }
185 }
186
187 if first_is_chain {
189 let chains: Vec<&Chain> = models
190 .iter()
191 .filter_map(|m| match m {
192 ModelRef::Chain(c) => Some(*c),
193 _ => None,
194 })
195 .collect();
196 Ok(CombinedResult::Chain(combine_chains(
197 chains,
198 Some(weights),
199 )?))
200 } else {
201 let texts: Vec<&Text> = models
202 .iter()
203 .filter_map(|m| match m {
204 ModelRef::Text(t) => Some(*t),
205 _ => None,
206 })
207 .collect();
208 Ok(CombinedResult::Text(combine_texts(texts, Some(weights))?))
209 }
210}
211
/// The output of [`combine_models`].
///
/// Holds a combined [`Chain`] when the inputs were `ModelRef::Chain` values, or a combined
/// [`Text`] when they were `ModelRef::Text` values.
pub enum CombinedResult {
    /// Result of combining chain inputs.
    Chain(Chain),
    /// Result of combining text inputs.
    Text(Text),
}
217
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a single-sentence chain from `words` with the given state size.
    fn chain_from(words: &[&str], state_size: usize) -> Chain {
        let corpus = vec![words.iter().map(|w| w.to_string()).collect::<Vec<String>>()];
        Chain::new(&corpus, state_size)
    }

    #[test]
    fn test_combine_chains_equal_weights() {
        let chain1 = chain_from(&["hello", "world"], 1);
        let chain2 = chain_from(&["hello", "rust"], 1);

        let combined = combine_chains(vec![&chain1, &chain2], None).unwrap();
        assert_eq!(combined.state_size(), 1);
    }

    #[test]
    fn test_combine_chains_explicit_weights() {
        let chain1 = chain_from(&["hello", "world"], 1);
        let chain2 = chain_from(&["hello", "rust"], 1);

        let combined = combine_chains(vec![&chain1, &chain2], Some(vec![2.0, 0.5])).unwrap();
        assert_eq!(combined.state_size(), 1);
    }

    #[test]
    fn test_combine_chains_rejects_empty_input() {
        let result = combine_chains(Vec::new(), None);
        assert!(matches!(result, Err(MarkovError::CombineError(_))));
    }

    #[test]
    fn test_combine_chains_rejects_weight_length_mismatch() {
        let chain1 = chain_from(&["hello", "world"], 1);
        let chain2 = chain_from(&["hello", "rust"], 1);

        let result = combine_chains(vec![&chain1, &chain2], Some(vec![1.0]));
        assert!(matches!(result, Err(MarkovError::CombineError(_))));
    }

    #[test]
    fn test_combine_texts() {
        let text1 = "Hello world. Goodbye world.";
        let text2 = "Hello rust. Goodbye rust.";
        let model1 = Text::new(text1, 2, true, true, None).unwrap();
        let model2 = Text::new(text2, 2, true, true, None).unwrap();

        let combined = combine_texts(vec![&model1, &model2], None).unwrap();
        assert_eq!(combined.state_size(), 2);
    }

    #[test]
    fn test_combine_texts_mismatched_state_sizes() {
        let model1 = Text::new("Hello world. Goodbye world.", 1, true, true, None).unwrap();
        let model2 = Text::new("Hello rust. Goodbye rust.", 2, true, true, None).unwrap();

        let result = combine_texts(vec![&model1, &model2], None);
        assert!(matches!(result, Err(MarkovError::CombineError(_))));
    }

    #[test]
    fn test_combine_mismatched_state_sizes() {
        let chain1 = chain_from(&["hello"], 1);
        let chain2 = chain_from(&["hello", "world"], 2);

        let result = combine_chains(vec![&chain1, &chain2], None);
        assert!(result.is_err());
    }

    #[test]
    fn test_combine_models_chains() {
        let chain1 = chain_from(&["hello", "world"], 1);
        let chain2 = chain_from(&["hello", "rust"], 1);

        let result = combine_models(vec![ModelRef::Chain(&chain1), ModelRef::Chain(&chain2)], None);
        assert!(matches!(result, Ok(CombinedResult::Chain(_))));
    }

    #[test]
    fn test_combine_models_rejects_mixed_types() {
        let chain = chain_from(&["hello", "world", "again"], 2);
        let text = Text::new("Hello rust. Goodbye rust.", 2, true, true, None).unwrap();

        let result = combine_models(vec![ModelRef::Chain(&chain), ModelRef::Text(&text)], None);
        assert!(matches!(result, Err(MarkovError::CombineError(_))));
    }
}