Skip to main content

rusty_genius_stem/
lib.rs

1use anyhow::Result;
2use facecrab::AssetAuthority;
3use futures::channel::mpsc;
4use futures::sink::SinkExt;
5use futures::StreamExt;
6use rusty_genius_core::protocol::{AssetEvent, BrainstemInput, BrainstemOutput};
7use rusty_genius_cortex::{create_engine, Engine};
8use std::time::{Duration, Instant};
9
/// Policy deciding when an idle orchestrator unloads the engine's model.
///
/// The enum only wraps a `Duration`, so it is cheap to copy and compare.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum CortexStrategy {
    /// Unload as soon as the orchestrator goes idle (a zero-length idle window).
    Immediate,
    /// Unload once the orchestrator has been idle for the given duration.
    HibernateAfter(Duration),
    /// Never unload because of inactivity.
    KeepAlive,
}
16
17pub struct Orchestrator {
18    engine: Box<dyn Engine>,
19    asset_authority: AssetAuthority,
20    strategy: CortexStrategy,
21    last_activity: Instant,
22}
23
24impl Orchestrator {
25    pub async fn new() -> Result<Self> {
26        // In a real app, we might need configuration here
27        let engine = create_engine().await;
28        let asset_authority = AssetAuthority::new()?;
29        Ok(Self {
30            engine,
31            asset_authority,
32            strategy: CortexStrategy::HibernateAfter(Duration::from_secs(300)), // Default 5 mins
33            last_activity: Instant::now(),
34        })
35    }
36
37    /// Run the main event loop
38    /// Consumes BrainstemInput stream, produces BrainstemOutput stream
39    pub async fn run(
40        &mut self,
41        mut input_rx: mpsc::Receiver<BrainstemInput>,
42        mut output_tx: mpsc::Sender<BrainstemOutput>,
43    ) -> Result<()> {
44        loop {
45            // Determine timeout based on strategy
46            let timeout_duration = match self.strategy {
47                CortexStrategy::HibernateAfter(duration) => Some(duration),
48                CortexStrategy::Immediate => Some(Duration::ZERO), // Or very small
49                CortexStrategy::KeepAlive => None,
50            };
51
52            let next_activity = if let Some(d) = timeout_duration {
53                // Calculate when we should hibernate if no activity
54                let elapsed = self.last_activity.elapsed();
55                if elapsed >= d {
56                    // Time to hibernate!
57                    if let Err(e) = self.engine.unload_model().await {
58                        eprintln!("Failed to hibernate engine: {}", e);
59                    }
60                    // Wait for next message indefinitely since we are hibernated/unloaded
61                    None
62                } else {
63                    Some(d - elapsed)
64                }
65            } else {
66                None
67            };
68
69            let msg_option = if let Some(wait_time) = next_activity {
70                match async_std::future::timeout(wait_time, input_rx.next()).await {
71                    Ok(msg) => msg,
72                    Err(_) => {
73                        // Timeout expired, loop back to check (should trigger hibernation)
74                        continue;
75                    }
76                }
77            } else {
78                // Wait indefinitely
79                input_rx.next().await
80            };
81
82            match msg_option {
83                Some(msg) => {
84                    self.last_activity = Instant::now(); // Update activity
85                    match msg {
86                        BrainstemInput::LoadModel(name_or_path) => {
87                            let mut events =
88                                self.asset_authority.ensure_model_stream(&name_or_path);
89                            let mut path_to_load = name_or_path.clone();
90
91                            while let Some(event) = events.next().await {
92                                if let AssetEvent::Complete(path) = &event {
93                                    path_to_load = path.clone();
94                                }
95                                if let Err(_) = output_tx.send(BrainstemOutput::Asset(event)).await
96                                {
97                                    break;
98                                }
99                            }
100
101                            // Finally load into engine
102                            if let Err(e) = self.engine.load_model(&path_to_load).await {
103                                let _ = output_tx.send(BrainstemOutput::Error(e.to_string())).await;
104                            }
105                        }
106                        BrainstemInput::Infer { prompt, config: _ } => {
107                            // Trigger inference
108                            match self.engine.infer(&prompt).await {
109                                Ok(mut event_rx) => {
110                                    // Forward events to output
111                                    while let Some(event_res) = event_rx.next().await {
112                                        match event_res {
113                                            Ok(event) => {
114                                                if let Err(_) = output_tx
115                                                    .send(BrainstemOutput::Event(event))
116                                                    .await
117                                                {
118                                                    break; // Receiver dropped
119                                                }
120                                            }
121                                            Err(e) => {
122                                                let _ = output_tx
123                                                    .send(BrainstemOutput::Error(e.to_string()))
124                                                    .await;
125                                            }
126                                        }
127                                    }
128                                }
129                                Err(e) => {
130                                    let _ =
131                                        output_tx.send(BrainstemOutput::Error(e.to_string())).await;
132                                }
133                            }
134                        }
135                        BrainstemInput::Stop => {
136                            break;
137                        }
138                    }
139                }
140                None => {
141                    break; // Channel closed
142                }
143            }
144        }
145        Ok(())
146    }
147}