agent_stream_kit/
askit.rs

1use std::sync::atomic::AtomicUsize;
2use std::sync::{Arc, Mutex};
3
4use tokio::sync::{Mutex as AsyncMutex, mpsc};
5
6use crate::FnvIndexMap;
7use crate::agent::{Agent, AgentMessage, AgentStatus, agent_new};
8use crate::config::{AgentConfigs, AgentConfigsMap};
9use crate::context::AgentContext;
10use crate::definition::{AgentConfigSpecs, AgentDefinition, AgentDefinitions};
11use crate::error::AgentError;
12use crate::message::{self, AgentEventMessage};
13use crate::registry;
14use crate::spec::{self, AgentSpec, AgentStreamSpec, ChannelSpec};
15use crate::stream::{AgentStream, AgentStreamInfo, AgentStreams};
16use crate::value::AgentValue;
17
18const MESSAGE_LIMIT: usize = 1024;
19
20#[derive(Clone)]
21pub struct ASKit {
22    // agent id -> agent
23    pub(crate) agents: Arc<Mutex<FnvIndexMap<String, Arc<AsyncMutex<Box<dyn Agent>>>>>>,
24
25    // agent id -> sender
26    pub(crate) agent_txs: Arc<Mutex<FnvIndexMap<String, AgentMessageSender>>>,
27
28    // board name -> [board out agent id]
29    pub(crate) board_out_agents: Arc<Mutex<FnvIndexMap<String, Vec<String>>>>,
30
31    // board name -> value
32    pub(crate) board_value: Arc<Mutex<FnvIndexMap<String, AgentValue>>>,
33
34    // source agent id -> [target agent id / source handle / target handle]
35    pub(crate) channels: Arc<Mutex<FnvIndexMap<String, Vec<(String, String, String)>>>>,
36
37    // agent def name -> agent definition
38    pub(crate) defs: Arc<Mutex<AgentDefinitions>>,
39
40    // agent streams (stream id -> stream)
41    pub(crate) streams: Arc<Mutex<AgentStreams>>,
42
43    // agent def name -> config
44    pub(crate) global_configs_map: Arc<Mutex<FnvIndexMap<String, AgentConfigs>>>,
45
46    // message sender
47    pub(crate) tx: Arc<Mutex<Option<mpsc::Sender<AgentEventMessage>>>>,
48
49    // observers
50    pub(crate) observers: Arc<Mutex<FnvIndexMap<usize, Box<dyn ASKitObserver + Sync + Send>>>>,
51}
52
53impl ASKit {
54    pub fn new() -> Self {
55        Self {
56            agents: Default::default(),
57            agent_txs: Default::default(),
58            board_out_agents: Default::default(),
59            board_value: Default::default(),
60            channels: Default::default(),
61            defs: Default::default(),
62            streams: Default::default(),
63            global_configs_map: Default::default(),
64            tx: Arc::new(Mutex::new(None)),
65            observers: Default::default(),
66        }
67    }
68
69    pub(crate) fn tx(&self) -> Result<mpsc::Sender<AgentEventMessage>, AgentError> {
70        self.tx
71            .lock()
72            .unwrap()
73            .clone()
74            .ok_or(AgentError::TxNotInitialized)
75    }
76
77    pub fn init() -> Result<Self, AgentError> {
78        let askit = Self::new();
79        askit.register_agents();
80        Ok(askit)
81    }
82
83    fn register_agents(&self) {
84        registry::register_inventory_agents(self);
85    }
86
87    pub async fn ready(&self) -> Result<(), AgentError> {
88        self.spawn_message_loop().await?;
89        self.start_agent_streams_on_start().await?;
90        Ok(())
91    }
92
93    pub fn quit(&self) {
94        let mut tx_lock = self.tx.lock().unwrap();
95        *tx_lock = None;
96    }
97
98    pub fn register_agent_definiton(&self, def: AgentDefinition) {
99        let def_name = def.name.clone();
100        let def_global_configs = def.global_configs.clone();
101
102        let mut defs = self.defs.lock().unwrap();
103        defs.insert(def.name.clone(), def);
104
105        // if there is a global config, set it
106        if let Some(def_global_configs) = def_global_configs {
107            let mut new_configs = AgentConfigs::default();
108            for (key, config_entry) in def_global_configs.iter() {
109                new_configs.set(key.clone(), config_entry.value.clone());
110            }
111            self.set_global_configs(def_name, new_configs);
112        }
113    }
114
115    pub fn get_agent_definitions(&self) -> AgentDefinitions {
116        let defs = self.defs.lock().unwrap();
117        defs.clone()
118    }
119
120    pub fn get_agent_definition(&self, def_name: &str) -> Option<AgentDefinition> {
121        let defs = self.defs.lock().unwrap();
122        defs.get(def_name).cloned()
123    }
124
125    pub fn get_agent_config_specs(&self, def_name: &str) -> Option<AgentConfigSpecs> {
126        let defs = self.defs.lock().unwrap();
127        let Some(def) = defs.get(def_name) else {
128            return None;
129        };
130        def.configs.clone()
131    }
132
133    pub fn get_agent_spec(&self, agent_id: &str) -> Option<AgentSpec> {
134        let agents = self.agents.lock().unwrap();
135        let Some(agent) = agents.get(agent_id) else {
136            return None;
137        };
138        let agent = agent.blocking_lock();
139        Some(agent.spec().clone())
140    }
141
142    // streams
143
144    /// Get info of the agent stream by id.
145    pub fn get_agent_stream_info(&self, id: &str) -> Option<AgentStreamInfo> {
146        let streams = self.streams.lock().unwrap();
147        streams.get(id).map(|stream| stream.into())
148    }
149
150    /// Get infos of all agent streams.
151    pub fn get_agent_stream_infos(&self) -> Vec<AgentStreamInfo> {
152        let streams = self.streams.lock().unwrap();
153        streams.values().map(|s| s.into()).collect()
154    }
155
156    /// Get the agent stream spec by id.
157    pub fn get_agent_stream_spec(&self, id: &str) -> Option<AgentStreamSpec> {
158        let streams = self.streams.lock().unwrap();
159        streams.get(id).map(|stream| stream.spec().clone())
160    }
161
162    /// Set the agent stream spec by id.
163    pub fn set_agent_stream_spec(&self, id: &str, spec: AgentStreamSpec) -> Result<(), AgentError> {
164        let mut streams = self.streams.lock().unwrap();
165        let Some(stream) = streams.get_mut(id) else {
166            return Err(AgentError::StreamNotFound(id.to_string()));
167        };
168        *stream.spec_mut() = spec;
169        Ok(())
170    }
171
172    /// Create a new agent stream with the given name.
173    /// If the name already exists, a unique name will be generated by appending a number suffix.
174    /// Returns the id of the new agent stream.
175    pub fn new_agent_stream(&self, name: &str) -> Result<String, AgentError> {
176        if !is_valid_stream_name(name) {
177            return Err(AgentError::InvalidStreamName(name.into()));
178        }
179        let new_name = self.unique_stream_name(name);
180        let spec = AgentStreamSpec::default();
181        let id = self.add_agent_stream(new_name, spec)?;
182        Ok(id)
183    }
184
185    pub fn rename_agent_stream(&self, id: &str, new_name: &str) -> Result<String, AgentError> {
186        if !is_valid_stream_name(new_name) {
187            return Err(AgentError::InvalidStreamName(new_name.into()));
188        }
189
190        // check if the new name is already used
191        let new_name = self.unique_stream_name(new_name);
192
193        let mut streams = self.streams.lock().unwrap();
194
195        // remove the original stream
196        let Some(mut stream) = streams.swap_remove(id) else {
197            return Err(AgentError::RenameStreamFailed(id.into()));
198        };
199
200        // insert renamed stream
201        stream.set_name(new_name.clone());
202        streams.insert(stream.id().to_string(), stream);
203        Ok(new_name)
204    }
205
206    pub fn unique_stream_name(&self, name: &str) -> String {
207        let mut new_name = name.trim().to_string();
208        let mut i = 2;
209        let streams = self.streams.lock().unwrap();
210        while streams.values().any(|stream| stream.name() == new_name) {
211            new_name = format!("{}{}", name, i);
212            i += 1;
213        }
214        new_name
215    }
216
217    pub fn add_agent_stream(
218        &self,
219        name: String,
220        spec: AgentStreamSpec,
221    ) -> Result<String, AgentError> {
222        let stream = AgentStream::new(name, spec);
223        let id = stream.id().to_string();
224
225        // add agents
226        for agent in &stream.spec().agents {
227            if let Err(e) = self.add_agent_internal(id.clone(), agent.clone()) {
228                log::error!("Failed to add_agent {}: {}", agent.id, e);
229            }
230        }
231
232        // add channels
233        for channel in &stream.spec().channels {
234            self.add_channel_internal(channel.clone())
235                .unwrap_or_else(|e| {
236                    log::error!("Failed to add_channel {}: {}", channel.source, e);
237                });
238        }
239
240        // add the given stream into streams
241        let mut streams = self.streams.lock().unwrap();
242        if streams.contains_key(&id) {
243            return Err(AgentError::DuplicateId(id.into()));
244        }
245        streams.insert(id.to_string(), stream);
246
247        Ok(id)
248    }
249
250    pub async fn remove_agent_stream(&self, id: &str) -> Result<(), AgentError> {
251        let mut stream = {
252            let mut streams = self.streams.lock().unwrap();
253            let Some(stream) = streams.swap_remove(id) else {
254                return Err(AgentError::StreamNotFound(id.to_string()));
255            };
256            stream
257        };
258
259        stream.stop(self).await?;
260
261        // Remove all agents and channels associated with the stream
262        for agent in &stream.spec().agents {
263            self.remove_agent_internal(&agent.id).await?;
264        }
265        for channel in &stream.spec().channels {
266            self.remove_channel_internal(channel);
267        }
268
269        Ok(())
270    }
271
272    pub fn copy_sub_stream(
273        &self,
274        agents: &Vec<AgentSpec>,
275        channels: &Vec<ChannelSpec>,
276    ) -> (Vec<AgentSpec>, Vec<ChannelSpec>) {
277        spec::copy_sub_stream(agents, channels)
278    }
279
280    pub async fn start_agent_stream(&self, id: &str) -> Result<(), AgentError> {
281        let mut stream = {
282            let mut streams = self.streams.lock().unwrap();
283            let Some(stream) = streams.swap_remove(id) else {
284                return Err(AgentError::StreamNotFound(id.to_string()));
285            };
286            stream
287        };
288
289        stream.start(self).await?;
290
291        let mut streams = self.streams.lock().unwrap();
292        streams.insert(id.to_string(), stream);
293        Ok(())
294    }
295
296    pub async fn stop_agent_stream(&self, id: &str) -> Result<(), AgentError> {
297        let mut stream = {
298            let mut streams = self.streams.lock().unwrap();
299            let Some(stream) = streams.swap_remove(id) else {
300                return Err(AgentError::StreamNotFound(id.to_string()));
301            };
302            stream
303        };
304
305        stream.stop(self).await?;
306
307        let mut streams = self.streams.lock().unwrap();
308        streams.insert(id.to_string(), stream);
309        Ok(())
310    }
311
312    // Agents
313
314    /// Create a new agent spec from the given agent definition name.
315    pub fn new_agent_spec(&self, def_name: &str) -> Result<AgentSpec, AgentError> {
316        let def = self
317            .get_agent_definition(def_name)
318            .ok_or_else(|| AgentError::AgentDefinitionNotFound(def_name.to_string()))?;
319        Ok(AgentSpec::from_def(&def))
320    }
321
322    /// Add an agent to the specified stream.
323    pub fn add_agent(&self, stream_id: String, spec: AgentSpec) -> Result<(), AgentError> {
324        let mut streams = self.streams.lock().unwrap();
325        let Some(stream) = streams.get_mut(&stream_id) else {
326            return Err(AgentError::StreamNotFound(stream_id.to_string()));
327        };
328        self.add_agent_internal(stream_id, spec.clone())?;
329        stream.spec_mut().add_agent(spec.clone());
330        Ok(())
331    }
332
333    fn add_agent_internal(&self, stream_id: String, spec: AgentSpec) -> Result<(), AgentError> {
334        let mut agents = self.agents.lock().unwrap();
335        if agents.contains_key(&spec.id) {
336            return Err(AgentError::AgentAlreadyExists(spec.id.to_string()));
337        }
338        let spec_id = spec.id.clone();
339        let mut agent = agent_new(self.clone(), spec_id.clone(), spec)?;
340        agent.set_stream_id(stream_id);
341        agents.insert(spec_id, Arc::new(AsyncMutex::new(agent)));
342        Ok(())
343    }
344
345    /// Get the agent by id.
346    pub fn get_agent(&self, agent_id: &str) -> Option<Arc<AsyncMutex<Box<dyn Agent>>>> {
347        let agents = self.agents.lock().unwrap();
348        agents.get(agent_id).cloned()
349    }
350
351    pub fn add_channel(&self, stream_id: &str, channel: ChannelSpec) -> Result<(), AgentError> {
352        let mut streams = self.streams.lock().unwrap();
353        let Some(stream) = streams.get_mut(stream_id) else {
354            return Err(AgentError::StreamNotFound(stream_id.to_string()));
355        };
356        stream.spec_mut().add_channels(channel.clone());
357        self.add_channel_internal(channel)?;
358        Ok(())
359    }
360
361    fn add_channel_internal(&self, channel: ChannelSpec) -> Result<(), AgentError> {
362        // check if the source agent exists
363        {
364            let agents = self.agents.lock().unwrap();
365            if !agents.contains_key(&channel.source) {
366                return Err(AgentError::SourceAgentNotFound(channel.source.to_string()));
367            }
368        }
369
370        // check if handles are valid
371        if channel.source_handle.is_empty() {
372            return Err(AgentError::EmptySourceHandle);
373        }
374        if channel.target_handle.is_empty() {
375            return Err(AgentError::EmptyTargetHandle);
376        }
377
378        let mut channels = self.channels.lock().unwrap();
379        if let Some(targets) = channels.get_mut(&channel.source) {
380            if targets
381                .iter()
382                .any(|(target, source_handle, target_handle)| {
383                    *target == channel.target
384                        && *source_handle == channel.source_handle
385                        && *target_handle == channel.target_handle
386                })
387            {
388                return Err(AgentError::ChannelAlreadyExists);
389            }
390            targets.push((channel.target, channel.source_handle, channel.target_handle));
391        } else {
392            channels.insert(
393                channel.source,
394                vec![(channel.target, channel.source_handle, channel.target_handle)],
395            );
396        }
397        Ok(())
398    }
399
400    pub async fn remove_agent(&self, stream_id: &str, agent_id: &str) -> Result<(), AgentError> {
401        {
402            let mut streams = self.streams.lock().unwrap();
403            let Some(stream) = streams.get_mut(stream_id) else {
404                return Err(AgentError::StreamNotFound(stream_id.to_string()));
405            };
406            stream.spec_mut().remove_agent(agent_id);
407        }
408        self.remove_agent_internal(agent_id).await?;
409        Ok(())
410    }
411
412    async fn remove_agent_internal(&self, agent_id: &str) -> Result<(), AgentError> {
413        self.stop_agent(agent_id).await?;
414
415        // remove from channels
416        {
417            let mut channels = self.channels.lock().unwrap();
418            let mut sources_to_remove = Vec::new();
419            for (source, targets) in channels.iter_mut() {
420                targets.retain(|(target, _, _)| target != agent_id);
421                if targets.is_empty() {
422                    sources_to_remove.push(source.clone());
423                }
424            }
425            for source in sources_to_remove {
426                channels.swap_remove(&source);
427            }
428            channels.swap_remove(agent_id);
429        }
430
431        // remove from agents
432        {
433            let mut agents = self.agents.lock().unwrap();
434            agents.swap_remove(agent_id);
435        }
436
437        Ok(())
438    }
439
440    pub fn remove_channel(&self, stream_id: &str, channel_id: &str) -> Result<(), AgentError> {
441        let mut stream = {
442            let mut streams = self.streams.lock().unwrap();
443            let Some(stream) = streams.swap_remove(stream_id) else {
444                return Err(AgentError::StreamNotFound(stream_id.to_string()));
445            };
446            stream
447        };
448
449        let Some(channel) = stream.spec_mut().remove_channel(channel_id) else {
450            return Err(AgentError::ChannelNotFound(channel_id.to_string()));
451        };
452        self.remove_channel_internal(&channel);
453        Ok(())
454    }
455
456    fn remove_channel_internal(&self, channel: &ChannelSpec) {
457        let mut channels = self.channels.lock().unwrap();
458        if let Some(targets) = channels.get_mut(&channel.source) {
459            targets.retain(|(target, source_handle, target_handle)| {
460                *target != channel.target
461                    || *source_handle != channel.source_handle
462                    || *target_handle != channel.target_handle
463            });
464            if targets.is_empty() {
465                channels.swap_remove(&channel.source);
466            }
467        }
468    }
469
470    pub async fn start_agent(&self, agent_id: &str) -> Result<(), AgentError> {
471        let agent = {
472            let agents = self.agents.lock().unwrap();
473            let Some(a) = agents.get(agent_id) else {
474                return Err(AgentError::AgentNotFound(agent_id.to_string()));
475            };
476            a.clone()
477        };
478        let def_name = {
479            let agent = agent.lock().await;
480            agent.def_name().to_string()
481        };
482        let uses_native_thread = {
483            let defs = self.defs.lock().unwrap();
484            let Some(def) = defs.get(&def_name) else {
485                return Err(AgentError::AgentDefinitionNotFound(agent_id.to_string()));
486            };
487            def.native_thread
488        };
489        let agent_status = {
490            let agent = agent.lock().await;
491            agent.status().clone()
492        };
493        if agent_status == AgentStatus::Init {
494            log::info!("Starting agent {}", agent_id);
495
496            if uses_native_thread {
497                let (tx, rx) = std::sync::mpsc::channel();
498
499                {
500                    let mut agent_txs = self.agent_txs.lock().unwrap();
501                    agent_txs.insert(agent_id.to_string(), AgentMessageSender::Sync(tx.clone()));
502                };
503
504                let agent_id = agent_id.to_string();
505                std::thread::spawn(async move || {
506                    if let Err(e) = agent.lock().await.start().await {
507                        log::error!("Failed to start agent {}: {}", agent_id, e);
508                    }
509
510                    while let Ok(message) = rx.recv() {
511                        match message {
512                            AgentMessage::Input { ctx, pin, value } => {
513                                agent
514                                    .lock()
515                                    .await
516                                    .process(ctx, pin, value)
517                                    .await
518                                    .unwrap_or_else(|e| {
519                                        log::error!("Process Error {}: {}", agent_id, e);
520                                    });
521                            }
522                            AgentMessage::Config { configs } => {
523                                agent.lock().await.set_configs(configs).unwrap_or_else(|e| {
524                                    log::error!("Config Error {}: {}", agent_id, e);
525                                });
526                            }
527                            AgentMessage::Stop => {
528                                break;
529                            }
530                        }
531                    }
532                });
533            } else {
534                let (tx, mut rx) = mpsc::channel(MESSAGE_LIMIT);
535
536                {
537                    let mut agent_txs = self.agent_txs.lock().unwrap();
538                    agent_txs.insert(agent_id.to_string(), AgentMessageSender::Async(tx.clone()));
539                };
540
541                let agent_id = agent_id.to_string();
542                tokio::spawn(async move {
543                    {
544                        let mut agent_guard = agent.lock().await;
545                        if let Err(e) = agent_guard.start().await {
546                            log::error!("Failed to start agent {}: {}", agent_id, e);
547                        }
548                    }
549
550                    while let Some(message) = rx.recv().await {
551                        match message {
552                            AgentMessage::Input { ctx, pin, value } => {
553                                agent
554                                    .lock()
555                                    .await
556                                    .process(ctx, pin, value)
557                                    .await
558                                    .unwrap_or_else(|e| {
559                                        log::error!("Process Error {}: {}", agent_id, e);
560                                    });
561                            }
562                            AgentMessage::Config { configs } => {
563                                agent.lock().await.set_configs(configs).unwrap_or_else(|e| {
564                                    log::error!("Config Error {}: {}", agent_id, e);
565                                });
566                            }
567                            AgentMessage::Stop => {
568                                rx.close();
569                                return;
570                            }
571                        }
572                    }
573                });
574                tokio::task::yield_now().await;
575            }
576        }
577        Ok(())
578    }
579
580    pub async fn stop_agent(&self, agent_id: &str) -> Result<(), AgentError> {
581        let agent = {
582            let agents = self.agents.lock().unwrap();
583            let Some(a) = agents.get(agent_id) else {
584                return Err(AgentError::AgentNotFound(agent_id.to_string()));
585            };
586            a.clone()
587        };
588
589        let agent_status = {
590            let agent = agent.lock().await;
591            agent.status().clone()
592        };
593        if agent_status == AgentStatus::Start {
594            log::info!("Stopping agent {}", agent_id);
595
596            {
597                let mut agent_txs = self.agent_txs.lock().unwrap();
598                if let Some(tx) = agent_txs.swap_remove(agent_id) {
599                    match tx {
600                        AgentMessageSender::Sync(tx) => {
601                            tx.send(AgentMessage::Stop).unwrap_or_else(|e| {
602                                log::error!(
603                                    "Failed to send stop message to agent {}: {}",
604                                    agent_id,
605                                    e
606                                );
607                            });
608                        }
609                        AgentMessageSender::Async(tx) => {
610                            tx.try_send(AgentMessage::Stop).unwrap_or_else(|e| {
611                                log::error!(
612                                    "Failed to send stop message to agent {}: {}",
613                                    agent_id,
614                                    e
615                                );
616                            });
617                        }
618                    }
619                }
620            }
621
622            agent.lock().await.stop().await?;
623        }
624
625        Ok(())
626    }
627
628    pub async fn set_agent_configs(
629        &self,
630        agent_id: String,
631        configs: AgentConfigs,
632    ) -> Result<(), AgentError> {
633        let agent = {
634            let agents = self.agents.lock().unwrap();
635            let Some(a) = agents.get(&agent_id) else {
636                return Err(AgentError::AgentNotFound(agent_id.to_string()));
637            };
638            a.clone()
639        };
640
641        let agent_status = {
642            let agent = agent.lock().await;
643            agent.status().clone()
644        };
645        if agent_status == AgentStatus::Init {
646            agent.lock().await.set_configs(configs.clone())?;
647        } else if agent_status == AgentStatus::Start {
648            let tx = {
649                let agent_txs = self.agent_txs.lock().unwrap();
650                let Some(tx) = agent_txs.get(&agent_id) else {
651                    return Err(AgentError::AgentTxNotFound(agent_id.to_string()));
652                };
653                tx.clone()
654            };
655            let message = AgentMessage::Config { configs };
656            match tx {
657                AgentMessageSender::Sync(tx) => {
658                    tx.send(message).map_err(|_| {
659                        AgentError::SendMessageFailed("Failed to send config message".to_string())
660                    })?;
661                }
662                AgentMessageSender::Async(tx) => {
663                    tx.send(message).await.map_err(|_| {
664                        AgentError::SendMessageFailed("Failed to send config message".to_string())
665                    })?;
666                }
667            }
668        }
669        Ok(())
670    }
671
672    pub fn get_global_configs(&self, def_name: &str) -> Option<AgentConfigs> {
673        let global_configs_map = self.global_configs_map.lock().unwrap();
674        global_configs_map.get(def_name).cloned()
675    }
676
677    pub fn set_global_configs(&self, def_name: String, configs: AgentConfigs) {
678        let mut global_configs_map = self.global_configs_map.lock().unwrap();
679
680        let Some(existing_configs) = global_configs_map.get_mut(&def_name) else {
681            global_configs_map.insert(def_name, configs);
682            return;
683        };
684
685        for (key, value) in configs {
686            existing_configs.set(key, value);
687        }
688    }
689
690    pub fn get_global_configs_map(&self) -> AgentConfigsMap {
691        let global_configs_map = self.global_configs_map.lock().unwrap();
692        global_configs_map.clone()
693    }
694
695    pub fn set_global_configs_map(&self, new_configs_map: AgentConfigsMap) {
696        for (agent_name, new_configs) in new_configs_map {
697            self.set_global_configs(agent_name, new_configs);
698        }
699    }
700
701    pub async fn agent_input(
702        &self,
703        agent_id: String,
704        ctx: AgentContext,
705        pin: String,
706        value: AgentValue,
707    ) -> Result<(), AgentError> {
708        let agent: Arc<AsyncMutex<Box<dyn Agent>>> = {
709            let agents = self.agents.lock().unwrap();
710            let Some(a) = agents.get(&agent_id) else {
711                return Err(AgentError::AgentNotFound(agent_id.to_string()));
712            };
713            a.clone()
714        };
715
716        let agent_status = {
717            let agent = agent.lock().await;
718            agent.status().clone()
719        };
720        if agent_status != AgentStatus::Start {
721            return Ok(());
722        }
723
724        if pin.starts_with("config:") {
725            let config_key = pin[7..].to_string();
726            let mut agent = agent.lock().await;
727            agent.set_config(config_key.clone(), value.clone())?;
728            return Ok(());
729        }
730
731        let message = AgentMessage::Input {
732            ctx,
733            pin: pin.clone(),
734            value,
735        };
736
737        let tx = {
738            let agent_txs = self.agent_txs.lock().unwrap();
739            let Some(tx) = agent_txs.get(&agent_id) else {
740                return Err(AgentError::AgentTxNotFound(agent_id.to_string()));
741            };
742            tx.clone()
743        };
744        match tx {
745            AgentMessageSender::Sync(tx) => {
746                tx.send(message).map_err(|_| {
747                    AgentError::SendMessageFailed("Failed to send input message".to_string())
748                })?;
749            }
750            AgentMessageSender::Async(tx) => {
751                tx.send(message).await.map_err(|_| {
752                    AgentError::SendMessageFailed("Failed to send input message".to_string())
753                })?;
754            }
755        }
756
757        self.emit_agent_input(agent_id.to_string(), pin);
758
759        Ok(())
760    }
761
762    pub async fn send_agent_out(
763        &self,
764        agent_id: String,
765        ctx: AgentContext,
766        pin: String,
767        value: AgentValue,
768    ) -> Result<(), AgentError> {
769        message::send_agent_out(self, agent_id, ctx, pin, value).await
770    }
771
772    pub fn try_send_agent_out(
773        &self,
774        agent_id: String,
775        ctx: AgentContext,
776        pin: String,
777        value: AgentValue,
778    ) -> Result<(), AgentError> {
779        message::try_send_agent_out(self, agent_id, ctx, pin, value)
780    }
781
782    pub fn write_board_value(&self, name: String, value: AgentValue) -> Result<(), AgentError> {
783        self.try_send_board_out(name, AgentContext::new(), value)
784    }
785
786    pub fn write_var_value(
787        &self,
788        stream_id: &str,
789        name: &str,
790        value: AgentValue,
791    ) -> Result<(), AgentError> {
792        let var_name = format!("%{}/{}", stream_id, name);
793        self.try_send_board_out(var_name, AgentContext::new(), value)
794    }
795
796    pub(crate) fn try_send_board_out(
797        &self,
798        name: String,
799        ctx: AgentContext,
800        value: AgentValue,
801    ) -> Result<(), AgentError> {
802        message::try_send_board_out(self, name, ctx, value)
803    }
804
805    async fn spawn_message_loop(&self) -> Result<(), AgentError> {
806        // TODO: settings for the channel size
807        let (tx, mut rx) = mpsc::channel(4096);
808        {
809            let mut tx_lock = self.tx.lock().unwrap();
810            *tx_lock = Some(tx);
811        }
812
813        // spawn the main loop
814        let askit = self.clone();
815        tokio::spawn(async move {
816            while let Some(message) = rx.recv().await {
817                use AgentEventMessage::*;
818
819                match message {
820                    AgentOut {
821                        agent,
822                        ctx,
823                        pin,
824                        value,
825                    } => {
826                        message::agent_out(&askit, agent, ctx, pin, value).await;
827                    }
828                    BoardOut { name, ctx, value } => {
829                        message::board_out(&askit, name, ctx, value).await;
830                    }
831                }
832            }
833        });
834
835        tokio::task::yield_now().await;
836
837        Ok(())
838    }
839
840    async fn start_agent_streams_on_start(&self) -> Result<(), AgentError> {
841        let run_on_start_stream_ids;
842        {
843            let agent_streams = self.streams.lock().unwrap();
844            run_on_start_stream_ids = agent_streams
845                .values()
846                .filter(|s| s.spec().run_on_start)
847                .map(|s| s.id().to_string())
848                .collect::<Vec<_>>();
849        }
850
851        for id in run_on_start_stream_ids {
852            self.start_agent_stream(&id).await.unwrap_or_else(|e| {
853                log::error!("Failed to start agent stream: {}", e);
854            });
855        }
856        Ok(())
857    }
858
859    pub fn subscribe(&self, observer: Box<dyn ASKitObserver + Sync + Send>) -> usize {
860        let mut observers = self.observers.lock().unwrap();
861        let observer_id = new_observer_id();
862        observers.insert(observer_id, observer);
863        observer_id
864    }
865
866    pub fn unsubscribe(&self, observer_id: usize) {
867        let mut observers = self.observers.lock().unwrap();
868        observers.swap_remove(&observer_id);
869    }
870
871    pub(crate) fn emit_agent_config_updated(
872        &self,
873        agent_id: String,
874        key: String,
875        value: AgentValue,
876    ) {
877        self.notify_observers(ASKitEvent::AgentConfigUpdated(agent_id, key, value));
878    }
879
880    pub(crate) fn emit_agent_error(&self, agent_id: String, message: String) {
881        self.notify_observers(ASKitEvent::AgentError(agent_id, message));
882    }
883
884    pub(crate) fn emit_agent_input(&self, agent_id: String, pin: String) {
885        self.notify_observers(ASKitEvent::AgentIn(agent_id, pin));
886    }
887
888    pub(crate) fn emit_agent_spec_updated(&self, agent_id: String) {
889        self.notify_observers(ASKitEvent::AgentSpecUpdated(agent_id));
890    }
891
892    pub(crate) fn emit_board(&self, name: String, value: AgentValue) {
893        // // ignore variables
894        // if name.starts_with('%') {
895        //     return;
896        // }
897        self.notify_observers(ASKitEvent::Board(name, value));
898    }
899
900    fn notify_observers(&self, event: ASKitEvent) {
901        let observers = self.observers.lock().unwrap();
902        for (_id, observer) in observers.iter() {
903            observer.notify(&event);
904        }
905    }
906}
907
/// Returns `true` if `new_name` is acceptable as a stream name.
///
/// A valid name:
/// - is not empty or whitespace-only;
/// - may use `/` to form a path-like name, but must not start or end with `/`
///   and must not contain an empty segment (`//`);
/// - has no `.` or `..` segment. A name without `/` is a single segment, so a
///   bare `"."` or `".."` is rejected just like `"a/../b"`;
/// - contains none of `\ : * ? " < > |`.
fn is_valid_stream_name(new_name: &str) -> bool {
    const INVALID_CHARS: [char; 8] = ['\\', ':', '*', '?', '"', '<', '>', '|'];

    // Reject empty and whitespace-only names.
    if new_name.trim().is_empty() {
        return false;
    }

    // Disallow leading, trailing, or consecutive slashes. These checks can only
    // match when the name contains '/', so no separate "path-like" guard is needed.
    if new_name.starts_with('/') || new_name.ends_with('/') || new_name.contains("//") {
        return false;
    }

    // Disallow "." / ".." segments. This runs for every name, not only
    // path-like ones, so bare "." and ".." cannot slip through.
    if new_name
        .split('/')
        .any(|segment| segment == "." || segment == "..")
    {
        return false;
    }

    // Reject characters that are reserved in file names on common platforms.
    !new_name.contains(&INVALID_CHARS[..])
}
939
940#[derive(Clone, Debug)]
941pub enum ASKitEvent {
942    AgentConfigUpdated(String, String, AgentValue), // (agent_id, key, value)
943    AgentError(String, String),                     // (agent_id, message)
944    AgentIn(String, String),                        // (agent_id, pin)
945    AgentSpecUpdated(String),                       // (agent_id)
946    Board(String, AgentValue),                      // (board name, value)
947}
948
949pub trait ASKitObserver {
950    fn notify(&self, event: &ASKitEvent);
951}
952
/// Process-wide source of observer ids. Starts at 1, so the first id handed out is 1.
static OBSERVER_ID_COUNTER: AtomicUsize = AtomicUsize::new(1);

/// Returns a fresh observer id, unique within the process barring `usize` wraparound.
fn new_observer_id() -> usize {
    use std::sync::atomic::Ordering;
    // Relaxed is enough: the atomic add alone guarantees uniqueness, and no
    // other memory is synchronized through this counter.
    OBSERVER_ID_COUNTER.fetch_add(1, Ordering::Relaxed)
}
958
959// Agent Message
960
961#[derive(Clone)]
962pub enum AgentMessageSender {
963    Sync(std::sync::mpsc::Sender<AgentMessage>),
964    Async(mpsc::Sender<AgentMessage>),
965}