1use std::collections::HashMap;
9use std::sync::Arc;
10use tokio::sync::RwLock;
11
12use super::concurrency::{
13 ConcurrencyDecision, ConcurrencyPolicy, DefaultConcurrencyPolicy, RunningState,
14};
15use super::interaction::{Interaction, InteractionId, InteractionState};
16use super::thread::{Thread, ThreadId};
17use orchestral_core::store::{BroadcastEventBus, Event, EventBus, EventStore};
18use orchestral_core::types::TaskId;
19
/// Tunable limits and housekeeping switches for a [`ThreadRuntime`].
///
/// The type is plain data (`usize` + `bool`), so it also derives
/// `PartialEq`/`Eq` to let callers and tests compare configurations directly.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ThreadRuntimeConfig {
    /// Upper bound on the number of non-terminal interactions a thread may hold
    /// at once. When the bound is reached, events that would start a new
    /// interaction are rejected instead.
    pub max_interactions_per_thread: usize,
    /// When `true`, `ThreadRuntime::cancel_all_active` also removes every
    /// terminal interaction from the in-memory map after cancelling.
    pub auto_cleanup: bool,
}
28
29impl Default for ThreadRuntimeConfig {
30 fn default() -> Self {
31 Self {
32 max_interactions_per_thread: 10,
33 auto_cleanup: true,
34 }
35 }
36}
37
/// In-memory runtime for a single thread: it tracks the thread's interactions,
/// asks a concurrency policy what to do with each incoming event, and records
/// events through an event store and an event bus.
pub struct ThreadRuntime {
    /// The thread served by this runtime, guarded by an async read/write lock.
    pub thread: RwLock<Thread>,
    /// Interactions known to this runtime, keyed by interaction id.
    pub interactions: RwLock<HashMap<InteractionId, Interaction>>,
    /// Policy consulted by `handle_event` to decide how a new event is treated.
    pub concurrency_policy: Arc<dyn ConcurrencyPolicy>,
    /// Store every handled event is appended to, and the source for history queries.
    pub event_store: Arc<dyn EventStore>,
    /// Bus every handled event is published on after it has been stored.
    pub event_bus: Arc<dyn EventBus>,
    /// Limits and housekeeping switches for this runtime.
    pub config: ThreadRuntimeConfig,
}
53
54impl ThreadRuntime {
55 pub fn new(thread: Thread, event_store: Arc<dyn EventStore>) -> Self {
57 Self::new_with_bus(thread, event_store, Arc::new(BroadcastEventBus::default()))
58 }
59
60 pub fn new_with_bus(
62 thread: Thread,
63 event_store: Arc<dyn EventStore>,
64 event_bus: Arc<dyn EventBus>,
65 ) -> Self {
66 Self {
67 thread: RwLock::new(thread),
68 interactions: RwLock::new(HashMap::new()),
69 concurrency_policy: Arc::new(DefaultConcurrencyPolicy),
70 event_store,
71 event_bus,
72 config: ThreadRuntimeConfig::default(),
73 }
74 }
75
76 pub fn with_policy(
78 thread: Thread,
79 event_store: Arc<dyn EventStore>,
80 policy: Arc<dyn ConcurrencyPolicy>,
81 ) -> Self {
82 Self::with_policy_and_bus(
83 thread,
84 event_store,
85 policy,
86 Arc::new(BroadcastEventBus::default()),
87 )
88 }
89
90 pub fn with_policy_and_bus(
92 thread: Thread,
93 event_store: Arc<dyn EventStore>,
94 policy: Arc<dyn ConcurrencyPolicy>,
95 event_bus: Arc<dyn EventBus>,
96 ) -> Self {
97 Self {
98 thread: RwLock::new(thread),
99 interactions: RwLock::new(HashMap::new()),
100 concurrency_policy: policy,
101 event_store,
102 event_bus,
103 config: ThreadRuntimeConfig::default(),
104 }
105 }
106
107 pub fn with_config(
109 thread: Thread,
110 event_store: Arc<dyn EventStore>,
111 config: ThreadRuntimeConfig,
112 ) -> Self {
113 Self::with_config_and_bus(
114 thread,
115 event_store,
116 config,
117 Arc::new(BroadcastEventBus::default()),
118 )
119 }
120
121 pub fn with_config_and_bus(
123 thread: Thread,
124 event_store: Arc<dyn EventStore>,
125 config: ThreadRuntimeConfig,
126 event_bus: Arc<dyn EventBus>,
127 ) -> Self {
128 Self {
129 thread: RwLock::new(thread),
130 interactions: RwLock::new(HashMap::new()),
131 concurrency_policy: Arc::new(DefaultConcurrencyPolicy),
132 event_store,
133 event_bus,
134 config,
135 }
136 }
137
138 pub fn with_policy_and_config(
140 thread: Thread,
141 event_store: Arc<dyn EventStore>,
142 policy: Arc<dyn ConcurrencyPolicy>,
143 config: ThreadRuntimeConfig,
144 ) -> Self {
145 Self::with_policy_config_and_bus(
146 thread,
147 event_store,
148 policy,
149 config,
150 Arc::new(BroadcastEventBus::default()),
151 )
152 }
153
154 pub fn with_policy_config_and_bus(
156 thread: Thread,
157 event_store: Arc<dyn EventStore>,
158 policy: Arc<dyn ConcurrencyPolicy>,
159 config: ThreadRuntimeConfig,
160 event_bus: Arc<dyn EventBus>,
161 ) -> Self {
162 Self {
163 thread: RwLock::new(thread),
164 interactions: RwLock::new(HashMap::new()),
165 concurrency_policy: policy,
166 event_store,
167 event_bus,
168 config,
169 }
170 }
171
172 pub async fn thread_id(&self) -> ThreadId {
174 self.thread.read().await.id.clone()
175 }
176
177 pub async fn running_state(&self) -> RunningState {
179 let interactions = self.interactions.read().await;
180 let active_count = interactions
181 .values()
182 .filter(|i| !i.state.is_terminal())
183 .count();
184 let is_processing = interactions
185 .values()
186 .any(|i| i.state == InteractionState::Active);
187 let is_waiting_user = interactions
188 .values()
189 .any(|i| i.state == InteractionState::WaitingUser);
190 let is_waiting_event = interactions
191 .values()
192 .any(|i| i.state == InteractionState::WaitingEvent);
193
194 RunningState {
195 active_count,
196 is_processing,
197 is_waiting_user,
198 is_waiting_event,
199 }
200 }
201
202 pub async fn handle_event(&self, event: Event) -> Result<HandleEventResult, RuntimeError> {
204 self.validate_event(&event).await?;
205
206 let running_state = self.running_state().await;
208
209 let decision = self.concurrency_policy.decide(&running_state, &event);
211
212 match decision {
214 ConcurrencyDecision::InterruptAndStartNew => {
215 self.cancel_all_active().await;
217
218 let interaction_id = match self.create_interaction_if_allowed().await {
219 Ok(id) => id,
220 Err(reason) => {
221 self.persist_event(event).await?;
222 return Ok(HandleEventResult::Rejected { reason });
223 }
224 };
225
226 self.thread.write().await.touch();
228
229 self.persist_event(event.with_interaction_id(&interaction_id))
231 .await?;
232
233 Ok(HandleEventResult::Started { interaction_id })
234 }
235 ConcurrencyDecision::Reject { reason } => {
236 self.persist_event(event).await?;
238 Ok(HandleEventResult::Rejected { reason })
239 }
240 ConcurrencyDecision::Queue => {
241 self.persist_event(event).await?;
243 Ok(HandleEventResult::Rejected {
244 reason: "Queue policy is configured but queue execution is not implemented"
245 .to_string(),
246 })
247 }
248 ConcurrencyDecision::Parallel => {
249 let interaction_id = match self.create_interaction_if_allowed().await {
250 Ok(id) => id,
251 Err(reason) => {
252 self.persist_event(event).await?;
253 return Ok(HandleEventResult::Rejected { reason });
254 }
255 };
256
257 self.thread.write().await.touch();
259
260 self.persist_event(event.with_interaction_id(&interaction_id))
262 .await?;
263
264 Ok(HandleEventResult::Started { interaction_id })
265 }
266 ConcurrencyDecision::MergeIntoRunning => {
267 let interactions = self.interactions.read().await;
269 let active_id = interactions
270 .values()
271 .find(|i| i.state == InteractionState::Active)
272 .map(|i| i.id.clone());
273
274 if let Some(interaction_id) = active_id {
275 self.persist_event(event.with_interaction_id(&interaction_id))
277 .await?;
278 Ok(HandleEventResult::Merged { interaction_id })
279 } else {
280 drop(interactions);
282 let interaction_id = match self.create_interaction_if_allowed().await {
283 Ok(id) => id,
284 Err(reason) => {
285 self.persist_event(event).await?;
286 return Ok(HandleEventResult::Rejected { reason });
287 }
288 };
289
290 self.thread.write().await.touch();
291
292 self.persist_event(event.with_interaction_id(&interaction_id))
294 .await?;
295
296 Ok(HandleEventResult::Started { interaction_id })
297 }
298 }
299 }
300 }
301
302 async fn create_interaction_if_allowed(&self) -> Result<InteractionId, String> {
303 let thread_id = self.thread_id().await;
304 let interaction = Interaction::new(&thread_id);
305 let interaction_id = interaction.id.clone();
306
307 let mut interactions = self.interactions.write().await;
308 let active_count = interactions
309 .values()
310 .filter(|i| !i.state.is_terminal())
311 .count();
312 if active_count >= self.config.max_interactions_per_thread {
313 return Err(format!(
314 "Maximum active interactions ({}) reached",
315 self.config.max_interactions_per_thread
316 ));
317 }
318
319 interactions.insert(interaction_id.clone(), interaction);
320 Ok(interaction_id)
321 }
322
323 async fn validate_event(&self, event: &Event) -> Result<(), RuntimeError> {
324 let expected_thread_id = self.thread_id().await;
325 let got_thread_id = event.thread_id();
326 if expected_thread_id != got_thread_id {
327 return Err(RuntimeError::InvalidEvent(format!(
328 "thread_id mismatch (expected {}, got {})",
329 expected_thread_id, got_thread_id
330 )));
331 }
332
333 if !payload_is_valid(event) {
334 return Err(RuntimeError::InvalidEvent(
335 "payload must not be null for user/external events".to_string(),
336 ));
337 }
338
339 Ok(())
340 }
341
342 pub async fn cancel_all_active(&self) {
344 let mut interactions = self.interactions.write().await;
345 for interaction in interactions.values_mut() {
346 if !interaction.state.is_terminal() {
347 interaction.cancel();
348 }
349 }
350
351 if self.config.auto_cleanup {
353 interactions.retain(|_, i| !i.state.is_terminal());
354 }
355 }
356
357 pub async fn get_interaction(&self, id: &str) -> Option<Interaction> {
359 let interactions = self.interactions.read().await;
360 let key: InteractionId = id.into();
361 interactions.get(&key).cloned()
362 }
363
364 pub async fn add_task_to_interaction(
366 &self,
367 id: &str,
368 task_id: TaskId,
369 ) -> Result<(), RuntimeError> {
370 let mut interactions = self.interactions.write().await;
371 let key: InteractionId = id.into();
372 if let Some(interaction) = interactions.get_mut(&key) {
373 interaction.add_task(task_id);
374 Ok(())
375 } else {
376 Err(RuntimeError::InteractionNotFound(id.to_string()))
377 }
378 }
379
380 pub async fn find_resume_interaction(&self, event: &Event) -> Option<InteractionId> {
382 let target_state = match event {
383 Event::UserInput { .. } => InteractionState::WaitingUser,
384 Event::ExternalEvent { .. } => InteractionState::WaitingEvent,
385 _ => return None,
386 };
387
388 let interactions = self.interactions.read().await;
389 interactions
390 .values()
391 .filter(|i| i.state == target_state)
392 .max_by_key(|i| i.started_at)
393 .map(|i| i.id.clone())
394 }
395
396 pub async fn append_event_to_interaction(
398 &self,
399 interaction_id: &str,
400 event: Event,
401 ) -> Result<(), RuntimeError> {
402 self.validate_event(&event).await?;
403
404 let exists = {
405 let interactions = self.interactions.read().await;
406 let key: InteractionId = interaction_id.into();
407 interactions.contains_key(&key)
408 };
409 if !exists {
410 return Err(RuntimeError::InteractionNotFound(
411 interaction_id.to_string(),
412 ));
413 }
414
415 self.persist_event(event.with_interaction_id(interaction_id))
416 .await?;
417 self.thread.write().await.touch();
418 Ok(())
419 }
420
421 pub fn subscribe_events(&self) -> tokio::sync::broadcast::Receiver<Event> {
423 self.event_bus.subscribe()
424 }
425
426 pub async fn resume_interaction(&self, id: &str) -> Result<(), RuntimeError> {
428 let mut interactions = self.interactions.write().await;
429 let key: InteractionId = id.into();
430 let interaction = interactions
431 .get_mut(&key)
432 .ok_or_else(|| RuntimeError::InteractionNotFound(id.to_string()))?;
433 if interaction.state.is_terminal() {
434 return Err(RuntimeError::InvalidEvent(format!(
435 "interaction '{}' is terminal and cannot be resumed",
436 id
437 )));
438 }
439 interaction.resume();
440 Ok(())
441 }
442
443 pub async fn update_interaction_state(
445 &self,
446 id: &str,
447 state: InteractionState,
448 ) -> Result<(), RuntimeError> {
449 let mut interactions = self.interactions.write().await;
450 let key: InteractionId = id.into();
451 if let Some(interaction) = interactions.get_mut(&key) {
452 interaction.set_state(state);
453 Ok(())
454 } else {
455 Err(RuntimeError::InteractionNotFound(id.to_string()))
456 }
457 }
458
459 pub async fn complete_interaction(&self, id: &str) -> Result<(), RuntimeError> {
461 self.update_interaction_state(id, InteractionState::Completed)
462 .await
463 }
464
465 pub async fn fail_interaction(&self, id: &str) -> Result<(), RuntimeError> {
467 self.update_interaction_state(id, InteractionState::Failed)
468 .await
469 }
470
471 pub async fn active_interaction_ids(&self) -> Vec<InteractionId> {
473 let interactions = self.interactions.read().await;
474 interactions
475 .values()
476 .filter(|i| !i.state.is_terminal())
477 .map(|i| i.id.clone())
478 .collect()
479 }
480
481 pub async fn query_history(&self, limit: usize) -> Result<Vec<Event>, RuntimeError> {
483 let thread_id = self.thread_id().await;
484 let events = if limit == 0 {
485 self.event_store
486 .query_by_thread(thread_id.as_str())
487 .await
488 .map_err(|e| RuntimeError::StoreError(e.to_string()))?
489 } else {
490 self.event_store
491 .query_by_thread_with_limit(thread_id.as_str(), limit)
492 .await
493 .map_err(|e| RuntimeError::StoreError(e.to_string()))?
494 };
495 Ok(events)
496 }
497
498 pub async fn cleanup_completed(&self) {
500 let mut interactions = self.interactions.write().await;
501 interactions.retain(|_, i| !i.state.is_terminal());
502 }
503
504 async fn persist_event(&self, event: Event) -> Result<(), RuntimeError> {
505 self.event_store
506 .append(event.clone())
507 .await
508 .map_err(|e| RuntimeError::StoreError(e.to_string()))?;
509 self.event_bus
510 .publish(event)
511 .await
512 .map_err(|e| RuntimeError::Internal(format!("event bus publish failed: {}", e)))?;
513 Ok(())
514 }
515}
516
517fn payload_is_valid(event: &Event) -> bool {
518 match event {
519 Event::UserInput { payload, .. } | Event::ExternalEvent { payload, .. } => {
520 !payload.is_null()
521 }
522 _ => true,
523 }
524}
525
/// Outcome reported by `ThreadRuntime::handle_event` for a single event.
#[derive(Debug, Clone)]
pub enum HandleEventResult {
    /// A new interaction was created for the event.
    Started {
        /// Id of the newly created interaction.
        interaction_id: InteractionId,
    },
    /// The event was not acted on; it is still persisted by `handle_event`.
    Rejected {
        /// Human-readable explanation of why the event was rejected.
        reason: String,
    },
    /// The event was queued for later processing.
    ///
    /// NOTE(review): nothing in this file constructs this variant — a `Queue`
    /// decision is currently reported as `Rejected`.
    Queued,
    /// The event was attached to an interaction that was already active.
    Merged {
        /// Id of the interaction the event was merged into.
        interaction_id: InteractionId,
    },
}
547
/// Errors returned by `ThreadRuntime` operations.
///
/// Each variant carries a pre-formatted message rather than a source error.
#[derive(Debug, thiserror::Error)]
pub enum RuntimeError {
    /// An event-store append or query failed; holds the store error's message.
    #[error("Store error: {0}")]
    StoreError(String),

    /// No interaction is registered under the given id.
    #[error("Interaction not found: {0}")]
    InteractionNotFound(String),

    /// NOTE(review): not constructed anywhere in this file — presumably used by
    /// callers that look runtimes up by thread id; confirm.
    #[error("Thread not found: {0}")]
    ThreadNotFound(String),

    /// Unexpected internal failure; in this file, a failed event-bus publish.
    #[error("Internal error: {0}")]
    Internal(String),

    /// The event or request is not acceptable: a thread-id mismatch, a null
    /// payload on a user/external event, or resuming a terminal interaction.
    #[error("Invalid event: {0}")]
    InvalidEvent(String),
}
566
#[cfg(test)]
mod tests {
    use super::*;
    use crate::concurrency::{ParallelConcurrencyPolicy, QueueConcurrencyPolicy};
    use chrono::{Duration, Utc};
    use orchestral_core::store::{BroadcastEventBus, InMemoryEventStore};
    use serde_json::json;

    /// Fresh in-memory event store shared by the tests below.
    fn memory_store() -> Arc<InMemoryEventStore> {
        Arc::new(InMemoryEventStore::new())
    }

    /// Asserts that `result` is a rejection whose reason contains `needle`.
    fn assert_rejected_with(result: HandleEventResult, needle: &str) {
        match result {
            HandleEventResult::Rejected { reason } => {
                assert!(reason.contains(needle));
            }
            other => panic!("expected rejected result, got {:?}", other),
        }
    }

    /// With two interactions waiting on the user, the most recently started
    /// one is chosen for resumption.
    #[test]
    fn test_find_resume_interaction_prefers_latest_waiting_user() {
        tokio_test::block_on(async {
            let thread_id = "thread-1";
            let runtime = ThreadRuntime::new(Thread::with_id(thread_id), memory_store());

            let mut older = Interaction::with_id("older", thread_id);
            older.set_state(InteractionState::WaitingUser);
            older.started_at = Utc::now() - Duration::seconds(10);

            let mut newer = Interaction::with_id("newer", thread_id);
            newer.set_state(InteractionState::WaitingUser);
            newer.started_at = Utc::now();

            {
                let mut interactions = runtime.interactions.write().await;
                for interaction in vec![older, newer] {
                    interactions.insert(interaction.id.clone(), interaction);
                }
            }

            let event = Event::user_input(thread_id, "ignored", json!({"message":"resume"}));
            let found = runtime.find_resume_interaction(&event).await;
            assert_eq!(found.as_ref().map(|id| id.as_str()), Some("newer"));
        });
    }

    /// Appending to an interaction overwrites the interaction id carried by
    /// the incoming event.
    #[test]
    fn test_append_event_to_interaction_rewrites_user_interaction_id() {
        tokio_test::block_on(async {
            let thread_id = "thread-1";
            let runtime = ThreadRuntime::new(Thread::with_id(thread_id), memory_store());

            runtime
                .interactions
                .write()
                .await
                .insert("target".into(), Interaction::with_id("target", thread_id));

            let event = Event::user_input(thread_id, "wrong", json!({"text":"hello"}));
            runtime
                .append_event_to_interaction("target", event)
                .await
                .unwrap();

            let events = runtime.query_history(0).await.unwrap();
            assert_eq!(events.len(), 1);
            if let Event::UserInput { interaction_id, .. } = &events[0] {
                assert_eq!(interaction_id.as_str(), "target");
            } else {
                panic!("expected user_input event");
            }
        });
    }

    /// Handled events reach bus subscribers with the interaction id rewritten
    /// to the newly started interaction.
    #[test]
    fn test_handle_event_publishes_to_event_bus() {
        tokio_test::block_on(async {
            let thread_id = "thread-1";
            let runtime = ThreadRuntime::new_with_bus(
                Thread::with_id(thread_id),
                memory_store(),
                Arc::new(BroadcastEventBus::new(16)),
            );
            let mut sub = runtime.subscribe_events();

            let event = Event::user_input(thread_id, "cli", json!({"message":"hello"}));
            let result = runtime.handle_event(event).await.unwrap();
            assert!(matches!(result, HandleEventResult::Started { .. }));

            match sub.recv().await.expect("published event") {
                Event::UserInput {
                    interaction_id,
                    payload,
                    ..
                } => {
                    assert_ne!(interaction_id.as_str(), "cli");
                    assert_eq!(payload["message"], "hello");
                }
                _ => panic!("expected user_input event"),
            }
        });
    }

    /// Once the per-thread interaction limit is reached, further events are
    /// rejected even under a parallel policy.
    #[test]
    fn test_rejects_when_max_active_interactions_reached() {
        tokio_test::block_on(async {
            let thread_id = "thread-max";
            let config = ThreadRuntimeConfig {
                max_interactions_per_thread: 1,
                auto_cleanup: false,
            };
            let runtime = ThreadRuntime::with_policy_and_config(
                Thread::with_id(thread_id),
                memory_store(),
                Arc::new(ParallelConcurrencyPolicy::new(10)),
                config,
            );

            let first = Event::user_input(thread_id, "a", json!({"message":"first"}));
            let first_result = runtime.handle_event(first).await.unwrap();
            assert!(matches!(first_result, HandleEventResult::Started { .. }));

            let second = Event::user_input(thread_id, "b", json!({"message":"second"}));
            let second_result = runtime.handle_event(second).await.unwrap();
            assert_rejected_with(second_result, "Maximum active interactions (1) reached");
        });
    }

    /// A queue decision is surfaced as a rejection because queue execution is
    /// not implemented.
    #[test]
    fn test_queue_policy_returns_rejected_not_queued() {
        tokio_test::block_on(async {
            let thread_id = "thread-queue";
            let runtime = ThreadRuntime::with_policy(
                Thread::with_id(thread_id),
                memory_store(),
                Arc::new(QueueConcurrencyPolicy),
            );

            let first = Event::user_input(thread_id, "a", json!({"message":"first"}));
            let first_result = runtime.handle_event(first).await.unwrap();
            assert!(matches!(first_result, HandleEventResult::Started { .. }));

            let second = Event::user_input(thread_id, "b", json!({"message":"second"}));
            let second_result = runtime.handle_event(second).await.unwrap();
            assert_rejected_with(second_result, "Queue policy");
        });
    }
}