1use crate::error::{Error, Result};
7use peat_schema::command::v1::{CommandPriority, ConflictPolicy, HierarchicalCommand};
8use std::collections::HashMap;
9use std::sync::Arc;
10use tokio::sync::RwLock;
11
12#[derive(Debug, Clone, PartialEq)]
14pub enum ConflictResult {
15 NoConflict,
17 Conflict(Vec<HierarchicalCommand>),
19}
20
21pub struct ConflictResolver {
26 target_commands: Arc<RwLock<HashMap<String, Vec<HierarchicalCommand>>>>,
29}
30
31impl ConflictResolver {
32 pub fn new() -> Self {
34 Self {
35 target_commands: Arc::new(RwLock::new(HashMap::new())),
36 }
37 }
38
39 pub async fn check_conflict(&self, command: &HierarchicalCommand) -> ConflictResult {
44 let target_ids = self.extract_target_ids(command);
45 let commands = self.target_commands.read().await;
46
47 let mut conflicting = Vec::new();
48
49 for target_id in target_ids {
50 if let Some(existing) = commands.get(&target_id) {
51 conflicting.extend(existing.clone());
52 }
53 }
54
55 if conflicting.is_empty() {
56 ConflictResult::NoConflict
57 } else {
58 ConflictResult::Conflict(conflicting)
59 }
60 }
61
62 pub fn resolve(
67 &self,
68 commands: Vec<HierarchicalCommand>,
69 policy: ConflictPolicy,
70 ) -> Result<HierarchicalCommand> {
71 if commands.is_empty() {
72 return Err(Error::InvalidInput(
73 "Cannot resolve conflict with empty command list".to_string(),
74 ));
75 }
76
77 if commands.len() == 1 {
78 return Ok(commands.into_iter().next().expect("len checked to be 1"));
79 }
80
81 match policy {
82 ConflictPolicy::LastWriteWins => self.resolve_last_write_wins(commands),
83 ConflictPolicy::HighestPriorityWins => self.resolve_highest_priority_wins(commands),
84 ConflictPolicy::HighestAuthorityWins => self.resolve_highest_authority_wins(commands),
85 ConflictPolicy::MergeCompatible => self.resolve_merge_compatible(commands),
86 ConflictPolicy::RejectConflict => Err(Error::ConflictDetected(
87 "Conflict policy REJECT_CONFLICT: rejecting new command".to_string(),
88 )),
89 ConflictPolicy::Unspecified => Err(Error::InvalidInput(
90 "Conflict policy must be specified".to_string(),
91 )),
92 }
93 }
94
95 pub async fn register_command(&self, command: &HierarchicalCommand) -> Result<()> {
97 let target_ids = self.extract_target_ids(command);
98 let mut commands = self.target_commands.write().await;
99
100 for target_id in target_ids {
101 commands.entry(target_id).or_default().push(command.clone());
102 }
103
104 Ok(())
105 }
106
107 pub async fn unregister_command(&self, command_id: &str) -> Result<()> {
109 let mut commands = self.target_commands.write().await;
110
111 for (_, cmd_list) in commands.iter_mut() {
113 cmd_list.retain(|cmd| cmd.command_id != command_id);
114 }
115
116 commands.retain(|_, cmd_list| !cmd_list.is_empty());
118
119 Ok(())
120 }
121
122 fn extract_target_ids(&self, command: &HierarchicalCommand) -> Vec<String> {
124 command
125 .target
126 .as_ref()
127 .map(|t| t.target_ids.clone())
128 .unwrap_or_default()
129 }
130
131 fn resolve_last_write_wins(
135 &self,
136 mut commands: Vec<HierarchicalCommand>,
137 ) -> Result<HierarchicalCommand> {
138 commands.sort_by(|a, b| {
139 let a_time = a.issued_at.as_ref().map(|t| t.seconds).unwrap_or(0);
140 let b_time = b.issued_at.as_ref().map(|t| t.seconds).unwrap_or(0);
141 b_time.cmp(&a_time) });
143
144 Ok(commands
145 .into_iter()
146 .next()
147 .expect("commands verified non-empty at function entry"))
148 }
149
150 fn resolve_highest_priority_wins(
154 &self,
155 mut commands: Vec<HierarchicalCommand>,
156 ) -> Result<HierarchicalCommand> {
157 commands.sort_by(|a, b| {
158 let a_priority =
159 CommandPriority::try_from(a.priority).unwrap_or(CommandPriority::Routine);
160 let b_priority =
161 CommandPriority::try_from(b.priority).unwrap_or(CommandPriority::Routine);
162 b_priority.cmp(&a_priority) });
164
165 Ok(commands
166 .into_iter()
167 .next()
168 .expect("commands verified non-empty at function entry"))
169 }
170
171 fn resolve_highest_authority_wins(
179 &self,
180 mut commands: Vec<HierarchicalCommand>,
181 ) -> Result<HierarchicalCommand> {
182 commands.sort_by(|a, b| {
183 let a_authority = self.derive_authority_level(&a.originator_id);
184 let b_authority = self.derive_authority_level(&b.originator_id);
185 b_authority.cmp(&a_authority) });
187
188 Ok(commands
189 .into_iter()
190 .next()
191 .expect("commands verified non-empty at function entry"))
192 }
193
194 fn resolve_merge_compatible(
200 &self,
201 commands: Vec<HierarchicalCommand>,
202 ) -> Result<HierarchicalCommand> {
203 Ok(commands
206 .into_iter()
207 .next()
208 .expect("commands verified non-empty at function entry"))
209 }
210
211 fn derive_authority_level(&self, node_id: &str) -> u32 {
219 if node_id.starts_with("zone-") {
220 3
221 } else if node_id.starts_with("platoon-") || node_id.starts_with("squad-") {
222 2
223 } else {
224 1
225 }
226 }
227}
228
229impl Default for ConflictResolver {
230 fn default() -> Self {
231 Self::new()
232 }
233}
234
#[cfg(test)]
mod tests {
    use super::*;
    use peat_schema::command::v1::{command_target::Scope, CommandTarget};
    use peat_schema::common::v1::Timestamp;

    /// Builds a command aimed at exactly one target, with the given priority and issue time.
    fn create_test_command(
        command_id: &str,
        originator_id: &str,
        target_id: &str,
        priority: i32,
        issued_at_seconds: u64,
    ) -> HierarchicalCommand {
        HierarchicalCommand {
            command_id: command_id.to_string(),
            originator_id: originator_id.to_string(),
            target: Some(CommandTarget {
                scope: Scope::Individual as i32,
                target_ids: vec![target_id.to_string()],
            }),
            priority,
            issued_at: Some(Timestamp {
                seconds: issued_at_seconds,
                nanos: 0,
            }),
            conflict_policy: ConflictPolicy::HighestPriorityWins as i32,
            ..Default::default()
        }
    }

    #[tokio::test]
    async fn test_no_conflict_on_different_targets() {
        let resolver = ConflictResolver::new();

        let cmd1 = create_test_command("cmd-1", "node-1", "target-1", 3, 1000);
        resolver.register_command(&cmd1).await.unwrap();

        let cmd2 = create_test_command("cmd-2", "node-2", "target-2", 3, 1001);
        let result = resolver.check_conflict(&cmd2).await;

        assert_eq!(result, ConflictResult::NoConflict);
    }

    #[tokio::test]
    async fn test_conflict_on_same_target() {
        let resolver = ConflictResolver::new();

        let cmd1 = create_test_command("cmd-1", "node-1", "target-1", 3, 1000);
        resolver.register_command(&cmd1).await.unwrap();

        let cmd2 = create_test_command("cmd-2", "node-2", "target-1", 3, 1001);
        let result = resolver.check_conflict(&cmd2).await;

        match result {
            ConflictResult::Conflict(cmds) => {
                assert_eq!(cmds.len(), 1);
                assert_eq!(cmds[0].command_id, "cmd-1");
            }
            ConflictResult::NoConflict => panic!("Expected conflict"),
        }
    }

    // `resolve` is synchronous, so the tests below need no async runtime.

    #[test]
    fn test_last_write_wins() {
        let resolver = ConflictResolver::new();

        let cmd1 = create_test_command("cmd-1", "node-1", "target-1", 3, 1000);
        let cmd2 = create_test_command("cmd-2", "node-2", "target-1", 3, 1001);
        let cmd3 = create_test_command("cmd-3", "node-3", "target-1", 3, 999);

        let winner = resolver
            .resolve(vec![cmd1, cmd2, cmd3], ConflictPolicy::LastWriteWins)
            .unwrap();

        // cmd-2 carries the newest issued_at (1001).
        assert_eq!(winner.command_id, "cmd-2");
    }

    #[test]
    fn test_highest_priority_wins() {
        let resolver = ConflictResolver::new();

        let cmd1 = create_test_command(
            "cmd-1",
            "node-1",
            "target-1",
            CommandPriority::Routine as i32,
            1000,
        );
        let cmd2 = create_test_command(
            "cmd-2",
            "node-2",
            "target-1",
            CommandPriority::Flash as i32,
            1001,
        );
        let cmd3 = create_test_command(
            "cmd-3",
            "node-3",
            "target-1",
            CommandPriority::Immediate as i32,
            999,
        );

        let winner = resolver
            .resolve(vec![cmd1, cmd2, cmd3], ConflictPolicy::HighestPriorityWins)
            .unwrap();

        // Flash outranks Immediate and Routine.
        assert_eq!(winner.command_id, "cmd-2");
    }

    #[test]
    fn test_highest_authority_wins() {
        let resolver = ConflictResolver::new();

        let cmd1 = create_test_command("cmd-1", "node-1", "target-1", 3, 1000);
        let cmd2 = create_test_command("cmd-2", "squad-alpha", "target-1", 3, 1001);
        let cmd3 = create_test_command("cmd-3", "zone-leader", "target-1", 3, 999);

        let winner = resolver
            .resolve(vec![cmd1, cmd2, cmd3], ConflictPolicy::HighestAuthorityWins)
            .unwrap();

        // A `zone-` originator outranks `squad-` and untiered nodes.
        assert_eq!(winner.command_id, "cmd-3");
    }

    #[test]
    fn test_reject_conflict() {
        let resolver = ConflictResolver::new();

        let cmd1 = create_test_command("cmd-1", "node-1", "target-1", 3, 1000);
        let cmd2 = create_test_command("cmd-2", "node-2", "target-1", 3, 1001);

        let result = resolver.resolve(vec![cmd1, cmd2], ConflictPolicy::RejectConflict);

        assert!(matches!(result, Err(Error::ConflictDetected(_))));
    }

    #[test]
    fn test_resolve_empty_list_is_invalid_input() {
        let resolver = ConflictResolver::new();

        let result = resolver.resolve(Vec::new(), ConflictPolicy::LastWriteWins);

        assert!(matches!(result, Err(Error::InvalidInput(_))));
    }

    #[test]
    fn test_resolve_single_command_bypasses_policy() {
        let resolver = ConflictResolver::new();

        let cmd = create_test_command("cmd-1", "node-1", "target-1", 3, 1000);

        // Even a rejecting policy returns the lone command: there is nothing to conflict with.
        let winner = resolver
            .resolve(vec![cmd], ConflictPolicy::RejectConflict)
            .unwrap();

        assert_eq!(winner.command_id, "cmd-1");
    }

    #[test]
    fn test_unspecified_policy_is_invalid_input() {
        let resolver = ConflictResolver::new();

        let cmd1 = create_test_command("cmd-1", "node-1", "target-1", 3, 1000);
        let cmd2 = create_test_command("cmd-2", "node-2", "target-1", 3, 1001);

        let result = resolver.resolve(vec![cmd1, cmd2], ConflictPolicy::Unspecified);

        assert!(matches!(result, Err(Error::InvalidInput(_))));
    }

    #[tokio::test]
    async fn test_unregister_command() {
        let resolver = ConflictResolver::new();

        let cmd1 = create_test_command("cmd-1", "node-1", "target-1", 3, 1000);
        resolver.register_command(&cmd1).await.unwrap();

        let cmd2 = create_test_command("cmd-2", "node-2", "target-1", 3, 1001);
        let result = resolver.check_conflict(&cmd2).await;
        assert!(matches!(result, ConflictResult::Conflict(_)));

        resolver.unregister_command("cmd-1").await.unwrap();

        let result = resolver.check_conflict(&cmd2).await;
        assert_eq!(result, ConflictResult::NoConflict);
    }

    #[test]
    fn test_authority_level_derivation() {
        let resolver = ConflictResolver::new();

        assert_eq!(resolver.derive_authority_level("zone-leader"), 3);
        assert_eq!(resolver.derive_authority_level("platoon-alpha"), 2);
        assert_eq!(resolver.derive_authority_level("squad-bravo"), 2);
        assert_eq!(resolver.derive_authority_level("node-1"), 1);
    }
}