// orchestral_runtime/thread_runtime.rs

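//! ThreadRuntime - thread lifecycle and interaction management.
//!
//! A `ThreadRuntime` manages, for a single thread:
//! - the thread's lifecycle,
//! - concurrency between the interactions started on that thread, and
//! - routing of incoming events (persisted to the event store and published on the event bus).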
7
8use std::collections::HashMap;
9use std::sync::Arc;
10use tokio::sync::RwLock;
11
12use super::concurrency::{
13    ConcurrencyDecision, ConcurrencyPolicy, DefaultConcurrencyPolicy, RunningState,
14};
15use super::interaction::{Interaction, InteractionId, InteractionState};
16use super::thread::{Thread, ThreadId};
17use orchestral_core::store::{BroadcastEventBus, Event, EventBus, EventStore};
18use orchestral_core::types::TaskId;
19
/// Configuration for [`ThreadRuntime`].
///
/// `ThreadRuntimeConfig::default()` allows 10 active interactions per thread
/// and enables auto-cleanup.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ThreadRuntimeConfig {
    /// Upper bound on non-terminal interactions tracked at once for one thread;
    /// creating a new interaction is rejected once this many are live.
    pub max_interactions_per_thread: usize,
    /// When `true`, terminal interactions are pruned from the runtime's map after
    /// `ThreadRuntime::cancel_all_active` runs. Completing or failing an interaction
    /// does not prune it; call `ThreadRuntime::cleanup_completed` for that.
    pub auto_cleanup: bool,
}
28
29impl Default for ThreadRuntimeConfig {
30    fn default() -> Self {
31        Self {
32            max_interactions_per_thread: 10,
33            auto_cleanup: true,
34        }
35    }
36}
37
38/// ThreadRuntime - manages thread lifecycle and interaction concurrency
39pub struct ThreadRuntime {
40    /// The thread being managed
41    pub thread: RwLock<Thread>,
42    /// Active interactions in this thread
43    pub interactions: RwLock<HashMap<InteractionId, Interaction>>,
44    /// Concurrency policy
45    pub concurrency_policy: Arc<dyn ConcurrencyPolicy>,
46    /// Event store
47    pub event_store: Arc<dyn EventStore>,
48    /// Realtime event bus
49    pub event_bus: Arc<dyn EventBus>,
50    /// Configuration
51    pub config: ThreadRuntimeConfig,
52}
53
54impl ThreadRuntime {
55    /// Create a new thread runtime
56    pub fn new(thread: Thread, event_store: Arc<dyn EventStore>) -> Self {
57        Self::new_with_bus(thread, event_store, Arc::new(BroadcastEventBus::default()))
58    }
59
60    /// Create a new thread runtime with custom event bus
61    pub fn new_with_bus(
62        thread: Thread,
63        event_store: Arc<dyn EventStore>,
64        event_bus: Arc<dyn EventBus>,
65    ) -> Self {
66        Self {
67            thread: RwLock::new(thread),
68            interactions: RwLock::new(HashMap::new()),
69            concurrency_policy: Arc::new(DefaultConcurrencyPolicy),
70            event_store,
71            event_bus,
72            config: ThreadRuntimeConfig::default(),
73        }
74    }
75
76    /// Create a new thread runtime with custom policy
77    pub fn with_policy(
78        thread: Thread,
79        event_store: Arc<dyn EventStore>,
80        policy: Arc<dyn ConcurrencyPolicy>,
81    ) -> Self {
82        Self::with_policy_and_bus(
83            thread,
84            event_store,
85            policy,
86            Arc::new(BroadcastEventBus::default()),
87        )
88    }
89
90    /// Create a new thread runtime with custom policy and event bus
91    pub fn with_policy_and_bus(
92        thread: Thread,
93        event_store: Arc<dyn EventStore>,
94        policy: Arc<dyn ConcurrencyPolicy>,
95        event_bus: Arc<dyn EventBus>,
96    ) -> Self {
97        Self {
98            thread: RwLock::new(thread),
99            interactions: RwLock::new(HashMap::new()),
100            concurrency_policy: policy,
101            event_store,
102            event_bus,
103            config: ThreadRuntimeConfig::default(),
104        }
105    }
106
107    /// Create a new thread runtime with custom config
108    pub fn with_config(
109        thread: Thread,
110        event_store: Arc<dyn EventStore>,
111        config: ThreadRuntimeConfig,
112    ) -> Self {
113        Self::with_config_and_bus(
114            thread,
115            event_store,
116            config,
117            Arc::new(BroadcastEventBus::default()),
118        )
119    }
120
121    /// Create a new thread runtime with custom config and event bus
122    pub fn with_config_and_bus(
123        thread: Thread,
124        event_store: Arc<dyn EventStore>,
125        config: ThreadRuntimeConfig,
126        event_bus: Arc<dyn EventBus>,
127    ) -> Self {
128        Self {
129            thread: RwLock::new(thread),
130            interactions: RwLock::new(HashMap::new()),
131            concurrency_policy: Arc::new(DefaultConcurrencyPolicy),
132            event_store,
133            event_bus,
134            config,
135        }
136    }
137
138    /// Create a new thread runtime with custom policy and config
139    pub fn with_policy_and_config(
140        thread: Thread,
141        event_store: Arc<dyn EventStore>,
142        policy: Arc<dyn ConcurrencyPolicy>,
143        config: ThreadRuntimeConfig,
144    ) -> Self {
145        Self::with_policy_config_and_bus(
146            thread,
147            event_store,
148            policy,
149            config,
150            Arc::new(BroadcastEventBus::default()),
151        )
152    }
153
154    /// Create a new thread runtime with custom policy, config, and event bus
155    pub fn with_policy_config_and_bus(
156        thread: Thread,
157        event_store: Arc<dyn EventStore>,
158        policy: Arc<dyn ConcurrencyPolicy>,
159        config: ThreadRuntimeConfig,
160        event_bus: Arc<dyn EventBus>,
161    ) -> Self {
162        Self {
163            thread: RwLock::new(thread),
164            interactions: RwLock::new(HashMap::new()),
165            concurrency_policy: policy,
166            event_store,
167            event_bus,
168            config,
169        }
170    }
171
172    /// Get the thread ID
173    pub async fn thread_id(&self) -> ThreadId {
174        self.thread.read().await.id.clone()
175    }
176
177    /// Get the current running state
178    pub async fn running_state(&self) -> RunningState {
179        let interactions = self.interactions.read().await;
180        let active_count = interactions
181            .values()
182            .filter(|i| !i.state.is_terminal())
183            .count();
184        let is_processing = interactions
185            .values()
186            .any(|i| i.state == InteractionState::Active);
187        let is_waiting_user = interactions
188            .values()
189            .any(|i| i.state == InteractionState::WaitingUser);
190        let is_waiting_event = interactions
191            .values()
192            .any(|i| i.state == InteractionState::WaitingEvent);
193
194        RunningState {
195            active_count,
196            is_processing,
197            is_waiting_user,
198            is_waiting_event,
199        }
200    }
201
202    /// Handle a new event
203    pub async fn handle_event(&self, event: Event) -> Result<HandleEventResult, RuntimeError> {
204        self.validate_event(&event).await?;
205
206        // Get current running state
207        let running_state = self.running_state().await;
208
209        // Ask policy for decision
210        let decision = self.concurrency_policy.decide(&running_state, &event);
211
212        // Handle based on decision
213        match decision {
214            ConcurrencyDecision::InterruptAndStartNew => {
215                // Cancel all active interactions
216                self.cancel_all_active().await;
217
218                let interaction_id = match self.create_interaction_if_allowed().await {
219                    Ok(id) => id,
220                    Err(reason) => {
221                        self.persist_event(event).await?;
222                        return Ok(HandleEventResult::Rejected { reason });
223                    }
224                };
225
226                // Touch the thread
227                self.thread.write().await.touch();
228
229                // Store the event with the runtime-generated interaction_id
230                self.persist_event(event.with_interaction_id(&interaction_id))
231                    .await?;
232
233                Ok(HandleEventResult::Started { interaction_id })
234            }
235            ConcurrencyDecision::Reject { reason } => {
236                // Store the event as-is (no interaction created)
237                self.persist_event(event).await?;
238                Ok(HandleEventResult::Rejected { reason })
239            }
240            ConcurrencyDecision::Queue => {
241                // Queueing is not implemented yet. Reject explicitly to avoid silent drops.
242                self.persist_event(event).await?;
243                Ok(HandleEventResult::Rejected {
244                    reason: "Queue policy is configured but queue execution is not implemented"
245                        .to_string(),
246                })
247            }
248            ConcurrencyDecision::Parallel => {
249                let interaction_id = match self.create_interaction_if_allowed().await {
250                    Ok(id) => id,
251                    Err(reason) => {
252                        self.persist_event(event).await?;
253                        return Ok(HandleEventResult::Rejected { reason });
254                    }
255                };
256
257                // Touch the thread
258                self.thread.write().await.touch();
259
260                // Store the event with the runtime-generated interaction_id
261                self.persist_event(event.with_interaction_id(&interaction_id))
262                    .await?;
263
264                Ok(HandleEventResult::Started { interaction_id })
265            }
266            ConcurrencyDecision::MergeIntoRunning => {
267                // Find the active interaction and merge
268                let interactions = self.interactions.read().await;
269                let active_id = interactions
270                    .values()
271                    .find(|i| i.state == InteractionState::Active)
272                    .map(|i| i.id.clone());
273
274                if let Some(interaction_id) = active_id {
275                    // Store the event with the active interaction_id
276                    self.persist_event(event.with_interaction_id(&interaction_id))
277                        .await?;
278                    Ok(HandleEventResult::Merged { interaction_id })
279                } else {
280                    // No active interaction, start new
281                    drop(interactions);
282                    let interaction_id = match self.create_interaction_if_allowed().await {
283                        Ok(id) => id,
284                        Err(reason) => {
285                            self.persist_event(event).await?;
286                            return Ok(HandleEventResult::Rejected { reason });
287                        }
288                    };
289
290                    self.thread.write().await.touch();
291
292                    // Store the event with the runtime-generated interaction_id
293                    self.persist_event(event.with_interaction_id(&interaction_id))
294                        .await?;
295
296                    Ok(HandleEventResult::Started { interaction_id })
297                }
298            }
299        }
300    }
301
302    async fn create_interaction_if_allowed(&self) -> Result<InteractionId, String> {
303        let thread_id = self.thread_id().await;
304        let interaction = Interaction::new(&thread_id);
305        let interaction_id = interaction.id.clone();
306
307        let mut interactions = self.interactions.write().await;
308        let active_count = interactions
309            .values()
310            .filter(|i| !i.state.is_terminal())
311            .count();
312        if active_count >= self.config.max_interactions_per_thread {
313            return Err(format!(
314                "Maximum active interactions ({}) reached",
315                self.config.max_interactions_per_thread
316            ));
317        }
318
319        interactions.insert(interaction_id.clone(), interaction);
320        Ok(interaction_id)
321    }
322
323    async fn validate_event(&self, event: &Event) -> Result<(), RuntimeError> {
324        let expected_thread_id = self.thread_id().await;
325        let got_thread_id = event.thread_id();
326        if expected_thread_id != got_thread_id {
327            return Err(RuntimeError::InvalidEvent(format!(
328                "thread_id mismatch (expected {}, got {})",
329                expected_thread_id, got_thread_id
330            )));
331        }
332
333        if !payload_is_valid(event) {
334            return Err(RuntimeError::InvalidEvent(
335                "payload must not be null for user/external events".to_string(),
336            ));
337        }
338
339        Ok(())
340    }
341
342    /// Cancel all active interactions
343    pub async fn cancel_all_active(&self) {
344        let mut interactions = self.interactions.write().await;
345        for interaction in interactions.values_mut() {
346            if !interaction.state.is_terminal() {
347                interaction.cancel();
348            }
349        }
350
351        // Auto-cleanup if enabled
352        if self.config.auto_cleanup {
353            interactions.retain(|_, i| !i.state.is_terminal());
354        }
355    }
356
357    /// Get an interaction by ID
358    pub async fn get_interaction(&self, id: &str) -> Option<Interaction> {
359        let interactions = self.interactions.read().await;
360        let key: InteractionId = id.into();
361        interactions.get(&key).cloned()
362    }
363
364    /// Add a task to an interaction
365    pub async fn add_task_to_interaction(
366        &self,
367        id: &str,
368        task_id: TaskId,
369    ) -> Result<(), RuntimeError> {
370        let mut interactions = self.interactions.write().await;
371        let key: InteractionId = id.into();
372        if let Some(interaction) = interactions.get_mut(&key) {
373            interaction.add_task(task_id);
374            Ok(())
375        } else {
376            Err(RuntimeError::InteractionNotFound(id.to_string()))
377        }
378    }
379
380    /// Find a waiting interaction that can be resumed by this event.
381    pub async fn find_resume_interaction(&self, event: &Event) -> Option<InteractionId> {
382        let target_state = match event {
383            Event::UserInput { .. } => InteractionState::WaitingUser,
384            Event::ExternalEvent { .. } => InteractionState::WaitingEvent,
385            _ => return None,
386        };
387
388        let interactions = self.interactions.read().await;
389        interactions
390            .values()
391            .filter(|i| i.state == target_state)
392            .max_by_key(|i| i.started_at)
393            .map(|i| i.id.clone())
394    }
395
396    /// Append an event to an existing interaction and keep event/interaction IDs consistent.
397    pub async fn append_event_to_interaction(
398        &self,
399        interaction_id: &str,
400        event: Event,
401    ) -> Result<(), RuntimeError> {
402        self.validate_event(&event).await?;
403
404        let exists = {
405            let interactions = self.interactions.read().await;
406            let key: InteractionId = interaction_id.into();
407            interactions.contains_key(&key)
408        };
409        if !exists {
410            return Err(RuntimeError::InteractionNotFound(
411                interaction_id.to_string(),
412            ));
413        }
414
415        self.persist_event(event.with_interaction_id(interaction_id))
416            .await?;
417        self.thread.write().await.touch();
418        Ok(())
419    }
420
421    /// Subscribe to the realtime event stream.
422    pub fn subscribe_events(&self) -> tokio::sync::broadcast::Receiver<Event> {
423        self.event_bus.subscribe()
424    }
425
426    /// Mark a waiting interaction active so execution can continue.
427    pub async fn resume_interaction(&self, id: &str) -> Result<(), RuntimeError> {
428        let mut interactions = self.interactions.write().await;
429        let key: InteractionId = id.into();
430        let interaction = interactions
431            .get_mut(&key)
432            .ok_or_else(|| RuntimeError::InteractionNotFound(id.to_string()))?;
433        if interaction.state.is_terminal() {
434            return Err(RuntimeError::InvalidEvent(format!(
435                "interaction '{}' is terminal and cannot be resumed",
436                id
437            )));
438        }
439        interaction.resume();
440        Ok(())
441    }
442
443    /// Update an interaction's state
444    pub async fn update_interaction_state(
445        &self,
446        id: &str,
447        state: InteractionState,
448    ) -> Result<(), RuntimeError> {
449        let mut interactions = self.interactions.write().await;
450        let key: InteractionId = id.into();
451        if let Some(interaction) = interactions.get_mut(&key) {
452            interaction.set_state(state);
453            Ok(())
454        } else {
455            Err(RuntimeError::InteractionNotFound(id.to_string()))
456        }
457    }
458
459    /// Complete an interaction
460    pub async fn complete_interaction(&self, id: &str) -> Result<(), RuntimeError> {
461        self.update_interaction_state(id, InteractionState::Completed)
462            .await
463    }
464
465    /// Fail an interaction
466    pub async fn fail_interaction(&self, id: &str) -> Result<(), RuntimeError> {
467        self.update_interaction_state(id, InteractionState::Failed)
468            .await
469    }
470
471    /// Get all active interaction IDs
472    pub async fn active_interaction_ids(&self) -> Vec<InteractionId> {
473        let interactions = self.interactions.read().await;
474        interactions
475            .values()
476            .filter(|i| !i.state.is_terminal())
477            .map(|i| i.id.clone())
478            .collect()
479    }
480
481    /// Query recent events for this thread
482    pub async fn query_history(&self, limit: usize) -> Result<Vec<Event>, RuntimeError> {
483        let thread_id = self.thread_id().await;
484        let events = if limit == 0 {
485            self.event_store
486                .query_by_thread(thread_id.as_str())
487                .await
488                .map_err(|e| RuntimeError::StoreError(e.to_string()))?
489        } else {
490            self.event_store
491                .query_by_thread_with_limit(thread_id.as_str(), limit)
492                .await
493                .map_err(|e| RuntimeError::StoreError(e.to_string()))?
494        };
495        Ok(events)
496    }
497
498    /// Cleanup completed interactions
499    pub async fn cleanup_completed(&self) {
500        let mut interactions = self.interactions.write().await;
501        interactions.retain(|_, i| !i.state.is_terminal());
502    }
503
504    async fn persist_event(&self, event: Event) -> Result<(), RuntimeError> {
505        self.event_store
506            .append(event.clone())
507            .await
508            .map_err(|e| RuntimeError::StoreError(e.to_string()))?;
509        self.event_bus
510            .publish(event)
511            .await
512            .map_err(|e| RuntimeError::Internal(format!("event bus publish failed: {}", e)))?;
513        Ok(())
514    }
515}
516
517fn payload_is_valid(event: &Event) -> bool {
518    match event {
519        Event::UserInput { payload, .. } | Event::ExternalEvent { payload, .. } => {
520            !payload.is_null()
521        }
522        _ => true,
523    }
524}
525
526/// Result of handling an event
527#[derive(Debug, Clone)]
528pub enum HandleEventResult {
529    /// A new interaction was started
530    Started {
531        /// ID of the new interaction
532        interaction_id: InteractionId,
533    },
534    /// The event was rejected
535    Rejected {
536        /// Reason for rejection
537        reason: String,
538    },
539    /// The event was queued for later processing
540    Queued,
541    /// The event was merged into an existing interaction
542    Merged {
543        /// ID of the interaction the event was merged into
544        interaction_id: InteractionId,
545    },
546}
547
548/// Runtime errors
549#[derive(Debug, thiserror::Error)]
550pub enum RuntimeError {
551    #[error("Store error: {0}")]
552    StoreError(String),
553
554    #[error("Interaction not found: {0}")]
555    InteractionNotFound(String),
556
557    #[error("Thread not found: {0}")]
558    ThreadNotFound(String),
559
560    #[error("Internal error: {0}")]
561    Internal(String),
562
563    #[error("Invalid event: {0}")]
564    InvalidEvent(String),
565}
566
#[cfg(test)]
mod tests {
    use super::*;
    use crate::concurrency::{ParallelConcurrencyPolicy, QueueConcurrencyPolicy};
    use chrono::{Duration, Utc};
    use orchestral_core::store::{BroadcastEventBus, InMemoryEventStore};
    use serde_json::json;

    /// Build a runtime with the default policy, config and bus over an in-memory store.
    fn runtime_for(thread_id: &str) -> ThreadRuntime {
        ThreadRuntime::new(
            Thread::with_id(thread_id),
            Arc::new(InMemoryEventStore::new()),
        )
    }

    #[test]
    fn test_find_resume_interaction_prefers_latest_waiting_user() {
        tokio_test::block_on(async {
            let thread_id = "thread-1";
            let runtime = runtime_for(thread_id);

            {
                let mut interactions = runtime.interactions.write().await;

                let mut older = Interaction::with_id("older", thread_id);
                older.set_state(InteractionState::WaitingUser);
                older.started_at = Utc::now() - Duration::seconds(10);
                interactions.insert(older.id.clone(), older);

                let mut newer = Interaction::with_id("newer", thread_id);
                newer.set_state(InteractionState::WaitingUser);
                newer.started_at = Utc::now();
                interactions.insert(newer.id.clone(), newer);
            }

            let event = Event::user_input(thread_id, "ignored", json!({"message":"resume"}));
            let found = runtime.find_resume_interaction(&event).await;
            assert_eq!(found.as_ref().map(|id| id.as_str()), Some("newer"));
        });
    }

    #[test]
    fn test_append_event_to_interaction_rewrites_user_interaction_id() {
        tokio_test::block_on(async {
            let thread_id = "thread-1";
            let runtime = runtime_for(thread_id);

            {
                let mut interactions = runtime.interactions.write().await;
                interactions.insert("target".into(), Interaction::with_id("target", thread_id));
            }

            let event = Event::user_input(thread_id, "wrong", json!({"text":"hello"}));
            runtime
                .append_event_to_interaction("target", event)
                .await
                .unwrap();

            let events = runtime.query_history(0).await.unwrap();
            assert_eq!(events.len(), 1);
            match &events[0] {
                Event::UserInput { interaction_id, .. } => {
                    assert_eq!(interaction_id.as_str(), "target");
                }
                _ => panic!("expected user_input event"),
            }
        });
    }

    #[test]
    fn test_append_event_to_missing_interaction_fails() {
        tokio_test::block_on(async {
            let thread_id = "thread-1";
            let runtime = runtime_for(thread_id);

            let event = Event::user_input(thread_id, "cli", json!({"text":"hello"}));
            let err = runtime
                .append_event_to_interaction("missing", event)
                .await
                .unwrap_err();
            match err {
                RuntimeError::InteractionNotFound(id) => assert_eq!(id, "missing"),
                other => panic!("expected interaction-not-found error, got {:?}", other),
            }
        });
    }

    #[test]
    fn test_handle_event_publishes_to_event_bus() {
        tokio_test::block_on(async {
            let thread_id = "thread-1";
            let runtime = ThreadRuntime::new_with_bus(
                Thread::with_id(thread_id),
                Arc::new(InMemoryEventStore::new()),
                Arc::new(BroadcastEventBus::new(16)),
            );
            let mut sub = runtime.subscribe_events();

            let event = Event::user_input(thread_id, "cli", json!({"message":"hello"}));
            let result = runtime.handle_event(event).await.unwrap();
            assert!(matches!(result, HandleEventResult::Started { .. }));

            let published = sub.recv().await.expect("published event");
            match published {
                Event::UserInput {
                    interaction_id,
                    payload,
                    ..
                } => {
                    assert_ne!(interaction_id.as_str(), "cli");
                    assert_eq!(payload["message"], "hello");
                }
                _ => panic!("expected user_input event"),
            }
        });
    }

    #[test]
    fn test_handle_event_rejects_thread_id_mismatch() {
        tokio_test::block_on(async {
            let runtime = runtime_for("thread-a");

            let event = Event::user_input("thread-b", "cli", json!({"message":"hello"}));
            let err = runtime.handle_event(event).await.unwrap_err();
            match err {
                RuntimeError::InvalidEvent(msg) => assert!(msg.contains("thread_id mismatch")),
                other => panic!("expected invalid event error, got {:?}", other),
            }
            // Validation happens before any interaction is created.
            assert!(runtime.active_interaction_ids().await.is_empty());
        });
    }

    #[test]
    fn test_handle_event_rejects_null_payload() {
        tokio_test::block_on(async {
            let thread_id = "thread-null";
            let runtime = runtime_for(thread_id);

            let event = Event::user_input(thread_id, "cli", json!(null));
            let err = runtime.handle_event(event).await.unwrap_err();
            assert!(matches!(err, RuntimeError::InvalidEvent(_)));
            assert!(runtime.active_interaction_ids().await.is_empty());
        });
    }

    #[test]
    fn test_rejects_when_max_active_interactions_reached() {
        tokio_test::block_on(async {
            let thread_id = "thread-max";
            let runtime = ThreadRuntime::with_policy_and_config(
                Thread::with_id(thread_id),
                Arc::new(InMemoryEventStore::new()),
                Arc::new(ParallelConcurrencyPolicy::new(10)),
                ThreadRuntimeConfig {
                    max_interactions_per_thread: 1,
                    auto_cleanup: false,
                },
            );

            let first = Event::user_input(thread_id, "a", json!({"message":"first"}));
            let first_result = runtime.handle_event(first).await.unwrap();
            assert!(matches!(first_result, HandleEventResult::Started { .. }));

            let second = Event::user_input(thread_id, "b", json!({"message":"second"}));
            let second_result = runtime.handle_event(second).await.unwrap();
            match second_result {
                HandleEventResult::Rejected { reason } => {
                    assert!(reason.contains("Maximum active interactions (1) reached"));
                }
                other => panic!("expected rejected result, got {:?}", other),
            }
        });
    }

    #[test]
    fn test_queue_policy_returns_rejected_not_queued() {
        tokio_test::block_on(async {
            let thread_id = "thread-queue";
            let runtime = ThreadRuntime::with_policy(
                Thread::with_id(thread_id),
                Arc::new(InMemoryEventStore::new()),
                Arc::new(QueueConcurrencyPolicy),
            );

            let first = Event::user_input(thread_id, "a", json!({"message":"first"}));
            let first_result = runtime.handle_event(first).await.unwrap();
            assert!(matches!(first_result, HandleEventResult::Started { .. }));

            let second = Event::user_input(thread_id, "b", json!({"message":"second"}));
            let second_result = runtime.handle_event(second).await.unwrap();
            match second_result {
                HandleEventResult::Rejected { reason } => {
                    assert!(reason.contains("Queue policy"));
                }
                other => panic!("expected rejected result, got {:?}", other),
            }
        });
    }
}