1use std::collections::{HashMap, HashSet};
2use stepflow_base::{ObjectStore, ObjectStoreContent, ObjectStoreFiltered, IdError, generate_id_type};
3use stepflow_data::{StateData, StateDataFiltered, var::{Var, VarId}, value::Value};
4use stepflow_step::{Step, StepId};
5use stepflow_action::{Action, ActionResult, ActionId};
6use super::{Error, dfs};
7
8
// Declares `SessionId`, the identifier type of a `Session`; it is built from a raw
// `u16` via `SessionId::new` (see `ObjectStoreContent::new_id` for `Session`).
generate_id_type!(SessionId);
10
11
/// A session over a tree of steps.
///
/// Owns the steps, actions and vars, the data collected so far, which action is
/// registered for which step, and the depth-first traversal position in the step tree.
#[derive(Debug)]
pub struct Session {
    /// Identifier of this session.
    id: SessionId,
    /// Data collected so far; step outputs and finished actions are merged into it.
    state_data: StateData,
    /// Action registered per step. The entry keyed by `step_id_all` is the catch-all
    /// action used when a step has no action of its own.
    actions: HashMap<StepId, ActionId>,

    /// All steps, including the two internal steps created by the constructor.
    step_store: ObjectStore<Step, StepId>,
    /// Actions that can be registered against steps.
    action_store: ObjectStore<Box<dyn Action + Sync + Send>, ActionId>,
    /// Vars that steps read (inputs) and write (outputs).
    var_store: ObjectStore<Box<dyn Var + Send + Sync>, VarId>,

    /// Internal step named "STEP_ID_ACTION_ALL"; used as the key of the catch-all action.
    step_id_all: StepId,
    /// Internal step named "SESSION_ROOT"; root of the step tree that user steps are pushed under.
    step_id_root: StepId,

    /// Depth-first traversal over the step tree, started at `step_id_root`.
    step_id_dfs: dfs::DepthFirstSearch,
}
58
59impl ObjectStoreContent for Session {
60 type IdType = SessionId;
61
62 fn new_id(id_val: u16) -> Self::IdType {
63 SessionId::new(id_val)
64 }
65
66 fn id(&self) -> &Self::IdType {
67 &self.id
68 }
69}
70
71impl Session {
72 pub fn new(id: SessionId) -> Self {
74 Self::with_capacity(id, 0, 0, 0)
75 }
76
77 pub fn with_capacity(id: SessionId, var_capacity: usize, step_capacity: usize, action_capacity: usize) -> Self {
79 let mut step_store = ObjectStore::with_capacity(step_capacity);
81
82 let step_id_all = step_store.insert_new_named(
84 "STEP_ID_ACTION_ALL",
85 |id| Ok(Step::new(id, None, vec![]))).unwrap();
86
87 let step_id_root = step_store.insert_new_named(
90 "SESSION_ROOT",
91 |id| Ok(Step::new(id, None, vec![]))).unwrap();
92
93 Session {
94 id,
95 state_data: StateData::new(),
96 actions: HashMap::new(),
97 step_store,
98 action_store: ObjectStore::with_capacity(action_capacity),
99 var_store: ObjectStore::with_capacity(var_capacity),
100 step_id_all: step_id_all,
101 step_id_root: step_id_root,
102 step_id_dfs: dfs::DepthFirstSearch::new(step_id_root),
103 }
104 }
105
106 pub fn id(&self) -> &SessionId {
108 &self.id
109 }
110
111 pub fn state_data(&self) -> &StateData {
113 &self.state_data
114 }
115
116 pub fn current_step(&self) -> Result<&StepId, Error> {
117 self.step_id_dfs.current().ok_or_else(|| Error::NoStateToEval)
118 }
119
120 pub fn step_store(&self) -> &ObjectStore<Step, StepId> {
122 &self.step_store
123 }
124
125 pub fn step_store_mut(&mut self) -> &mut ObjectStore<Step, StepId> {
127 &mut self.step_store
128 }
129
130 pub fn push_root_substep(&mut self, step_id: StepId) {
132 let root_step = self.step_store.get_mut(&self.step_id_root).unwrap();
133 root_step.push_substep(step_id);
134 }
135
136 pub fn action_store(&self) -> &ObjectStore<Box<dyn Action + Sync + Send>, ActionId> {
138 &self.action_store
139 }
140
141 pub fn action_store_mut(&mut self) -> &mut ObjectStore<Box<dyn Action + Sync + Send>, ActionId> {
142 &mut self.action_store
143 }
144
145 pub fn var_store(&self) -> &ObjectStore<Box<dyn Var + Sync + Send>, VarId> {
147 &self.var_store
148 }
149
150 pub fn var_store_mut(&mut self) -> &mut ObjectStore<Box<dyn Var + Sync + Send>, VarId> {
152 &mut self.var_store
153 }
154
155 pub fn set_action_for_step(&mut self, action_id: ActionId, step_id:Option<&StepId>)
161 -> Result<(), Error> {
162 let step_id_use = step_id.or(Some(&self.step_id_all)).unwrap();
163 if self.actions.contains_key(step_id_use) {
164 return Err(Error::StepId(IdError::IdAlreadyExists(step_id_use.clone())));
165 }
166 self.actions.insert(step_id_use.clone(), action_id);
167 Ok(())
168 }
169
170
171 fn try_enter_next_step(&mut self, step_output: Option<(&StepId, StateData)>)
175 -> Result<Option<StepId>, Error>
176 {
177 if let Some(output) = step_output {
178 if self.current_step()? != output.0 {
180 return Err(Error::StepId(IdError::IdUnexpected(output.0.clone())))
181 }
182
183 self.state_data.merge_from(output.1)
185 }
186
187 let state_data = &self.state_data;
188 let step_store = &self.step_store;
189 self.step_id_dfs.next(
190 |step_id| {
191 let step = step_store.get(step_id).ok_or_else(|| Error::StepId(IdError::IdMissing(step_id.clone())))?;
192 step.can_enter(&state_data).map_err(|e| Error::VarId(e))
193 },
194 |step_id| {
195 let step = step_store.get(step_id).ok_or_else(|| Error::StepId(IdError::IdMissing(step_id.clone())))?;
196 step.can_exit(&state_data).map_err(|e| Error::VarId(e))
197 },
198 &self.step_store)
199 }
200
201 fn call_action(&mut self, action_id: &ActionId, step_id: &StepId) -> Result<ActionResult, Error> {
202 fn get_step_input_output_vars(step: &Step) -> HashSet<VarId> {
204 step.get_input_vars()
205 .clone()
206 .unwrap_or_else(|| vec![])
207 .iter()
208 .chain(step.get_output_vars().iter())
209 .map(|id_ref| id_ref.clone())
210 .collect::<HashSet<VarId>>()
211 }
212
213 let step = self.step_store.get(step_id).ok_or_else(|| Error::StepId(IdError::IdMissing(step_id.clone())))?;
214 let step_name = self.step_store.name_from_id(&step_id);
215 let step_data: StateDataFiltered = StateDataFiltered::new(&self.state_data, get_step_input_output_vars(&step));
216 let vars = ObjectStoreFiltered::new(&self.var_store, get_step_input_output_vars(&step));
217
218 let action = self.action_store.get_mut(action_id).ok_or_else(|| Error::ActionId(IdError::IdMissing(action_id.clone())))?;
220 let action_result = action.start(&step, step_name, &step_data, &vars).map_err(|e| Error::from(e))?;
221 match &action_result {
222 ActionResult::Finished(state_data) => {
223 if !state_data.contains_only(&step.output_vars.iter().collect::<HashSet<_>>()) {
224 return Err(Error::InvalidStateDataError);
225 }
226 }
227 ActionResult::StartWith(_) |
228 ActionResult::CannotFulfill => ()
229 }
230 Ok(action_result)
231 }
232
233 pub fn advance(&mut self, step_output: Option<(&StepId, StateData)>)
245 -> Result<AdvanceBlockedOn, Error>
246 {
247 #[derive(Clone, Debug)]
248 enum States {
249 AdvanceStep,
250 GetSpecificAction(StepId, Option<Error>), GetGenericAction(StepId, Option<Error>), StartSpecific(ActionId, StepId, Option<Error>), StartGeneric(ActionId, StepId, Option<Error>), Done(Result<AdvanceBlockedOn, Error>)
255 }
256
257 let mut step_output = step_output;
265 let mut state = States::AdvanceStep;
266 loop {
267 state = match state.clone() {
268 States::Done(result) => return result,
269 States::AdvanceStep => {
270 let advance_result = self.try_enter_next_step(step_output);
271 step_output = None;
272 match &advance_result {
273 Ok(step_id_opt) => {
274 match step_id_opt {
275 Some(step_id) => States::GetSpecificAction(step_id.clone(), None),
276 None => States::Done(Ok(AdvanceBlockedOn::FinishedAdvancing)), }
278 }
279 Err(err) => {
280 let step_id = self.current_step()?.clone();
281 States::GetSpecificAction(step_id, Some(err.clone())) }
283 }
284 },
285 States::GetSpecificAction(step_id, error) => {
286 match self.actions.get(&step_id) {
287 Some(action_id) => States::StartSpecific(action_id.clone(), step_id, error),
288 None => States::GetGenericAction(step_id, error),
289 }
290 },
291 States::GetGenericAction(step_id, error) => {
292 match self.actions.get(&self.step_id_all) {
293 Some(action_id) => States::StartGeneric(action_id.clone(), step_id, error),
294 None => {
295 match error {
296 None => States::AdvanceStep, Some(err) => return Err(err), }
299 }
300 }
301 },
302 States::StartSpecific(action_id, step_id, error_opt) |
303 States::StartGeneric(action_id, step_id, error_opt) => {
304 let action_result = self.call_action(&action_id, &step_id)?;
305 match action_result {
306 ActionResult::StartWith(val) => {
307 States::Done(Ok(AdvanceBlockedOn::ActionStartWith(action_id, val)))
308 }
309 ActionResult::Finished(state_data) => {
310 self.state_data.merge_from(state_data.clone());
312 States::AdvanceStep
313 }
314 ActionResult::CannotFulfill => {
315 if matches!(state, States::StartSpecific(_,_,_)) {
316 States::GetGenericAction(step_id, error_opt)
318 } else {
319 States::Done(Ok(AdvanceBlockedOn::ActionCannotFulfill))
321 }
322 }
323 }
324 }
325 }
326 }
327 }
328
329 #[cfg(test)]
330 pub fn test_new() -> (Session, StepId) {
331 let mut session = Session::new(stepflow_test_util::test_id!(SessionId));
332 let root_step_id = session.step_store_mut().insert_new_named("root_step", |id| Ok(Step::new(id, None, vec![]))).unwrap();
333 session.push_root_substep(root_step_id.clone());
334 (session, root_step_id)
335 }
336
337 #[cfg(test)]
338 pub fn test_new_stringvar(&mut self) -> VarId {
339 let var_id = stepflow_test_util::test_id!(VarId);
340 let var = stepflow_data::var::StringVar::new(var_id);
341 let var_id = self.var_store.register( var.boxed()).unwrap();
342 var_id
343 }
344}
345
/// Why `Session::advance` stopped advancing.
#[derive(Debug, Clone)]
pub enum AdvanceBlockedOn {
    /// The action with this id returned `ActionResult::StartWith` with this value, so
    /// `advance` stopped and handed the value back to the caller.
    ActionStartWith(ActionId, Box<dyn Value>),

    /// The catch-all action returned `ActionResult::CannotFulfill` for the current step.
    ActionCannotFulfill,

    /// The step traversal reported that there is no further step.
    FinishedAdvancing,
}
358
359impl PartialEq for AdvanceBlockedOn {
360 fn eq(&self, other: &Self) -> bool {
361 match (self, other) {
362 (AdvanceBlockedOn::ActionStartWith(action_id, val),AdvanceBlockedOn::ActionStartWith(action_id_other, val_other)) => {
363 action_id == action_id_other && val == val_other
364 }
365 (AdvanceBlockedOn::ActionCannotFulfill, AdvanceBlockedOn::ActionCannotFulfill) |
366 (AdvanceBlockedOn::FinishedAdvancing, AdvanceBlockedOn::FinishedAdvancing) => {
367 true
368 }
369 _ => false
370 }
371 }
372}
373
374
#[cfg(test)]
mod tests {
    use stepflow_base::{ObjectStore, IdError};
    use stepflow_data::{StateData, var::VarId, value::{BoolValue, StringValue}};
    use stepflow_step::{Step, StepId};
    use stepflow_test_util::test_id;
    use stepflow_action::{SetDataAction, ActionId};
    use crate::test::TestAction;
    use super::super::Error;
    use super::{Session, SessionId, AdvanceBlockedOn};

    /// Builds a step with no input vars and no output vars.
    fn new_simple_step(id: StepId) -> Result<Step, IdError<StepId>> {
        Ok(Step::new(id, None, vec![]))
    }

    /// Inserts a new simple step and appends it as the last substep of `parent_id`.
    fn add_new_simple_substep(parent_id: &StepId, step_store: &mut ObjectStore<Step, StepId>) -> StepId {
        let substep_id = step_store.insert_new(new_simple_step).unwrap();
        push_substep(parent_id, substep_id, step_store)
    }

    /// Appends `step_id` as the last substep of `parent_id` and returns `step_id`.
    fn push_substep(parent_id: &StepId, step_id: StepId, step_store: &mut ObjectStore<Step, StepId>) -> StepId {
        let parent = step_store.get_mut(parent_id).unwrap();
        parent.push_substep(step_id.clone());
        step_id
    }

    /// Builds a `(current step, data)` output that sets `var_id` to the string `val`.
    fn step_str_output(session: &Session, var_id: &VarId, val: &'static str) -> (StepId, StateData) {
        let mut state_data = StateData::new();
        let var = session.var_store().get(var_id).unwrap();
        state_data.insert(var, StringValue::try_new(val).unwrap().boxed()).unwrap();
        (session.current_step().unwrap().clone(), state_data)
    }

    #[test]
    fn empty_session_advance() {
        let mut session = Session::new(test_id!(SessionId));
        let advance_result = session.advance(None);
        assert_eq!(advance_result, Ok(AdvanceBlockedOn::FinishedAdvancing));
    }

    #[test]
    fn progress_session_inputs_outputs() {
        let mut session = Session::new(test_id!(SessionId));

        let var_output1_id = session.test_new_stringvar();
        let var_input2_id = session.test_new_stringvar();
        let var_output2_id = session.test_new_stringvar();

        let root_step_id = session.step_store.insert_new_named(
            "root_step", |id| {
                Ok(Step::new(
                    id,
                    Some(vec![var_input2_id.clone()]),
                    vec![var_output1_id.clone(), var_output2_id.clone()]))
            })
            .unwrap();
        session.push_root_substep(root_step_id);

        let substep1_id = session.step_store_mut().insert_new_named("SubStep 1",
            |id| Ok(Step::new(id, None, vec![var_output1_id.clone()])))
            .unwrap();
        let substep2_id = session.step_store_mut().insert_new_named("SubStep 2",
            |id| Ok(Step::new(id, Some(vec![var_input2_id.clone()]), vec![var_output2_id.clone()])))
            .unwrap();

        let root_step = session.step_store_mut().get_mut(&root_step_id).unwrap();
        root_step.push_substep(substep1_id.clone());
        root_step.push_substep(substep2_id.clone());

        // The first move fails until the root step's input var_input2 is supplied.
        assert_eq!(
            session.try_enter_next_step(None),
            Err(Error::VarId(IdError::IdMissing(var_input2_id.clone()))));
        let output1 = step_str_output(&session, &var_input2_id, "input2");
        assert_eq!(
            session.try_enter_next_step(Some((&output1.0, output1.1))),
            Ok(Some(substep1_id.clone())));

        // Moving past substep 1 fails until its output var_output1 is supplied.
        assert_eq!(
            session.try_enter_next_step(None),
            Err(Error::VarId(IdError::IdMissing(var_output1_id.clone()))));
        let output2 = step_str_output(&session, &var_output1_id, "output1");
        assert_eq!(
            session.try_enter_next_step(Some((&output2.0, output2.1))),
            Ok(Some(substep2_id.clone())));

        // Likewise for substep 2 and var_output2.
        assert_eq!(
            session.try_enter_next_step(None),
            Err(Error::VarId(IdError::IdMissing(var_output2_id.clone()))));
        let output3 = step_str_output(&session, &var_output2_id, "output2");
        assert_eq!(
            session.try_enter_next_step(Some((&output3.0, output3.1))),
            Ok(None));

        // Once the traversal is done, it keeps reporting that nothing is left.
        assert_eq!(session.try_enter_next_step(None), Ok(None));
    }

    #[test]
    fn simple_action() {
        let (mut session, root_step_id) = Session::test_new();

        let substep1 = add_new_simple_substep(&root_step_id, session.step_store_mut());
        let substep2 = add_new_simple_substep(&root_step_id, session.step_store_mut());
        let substep3 = add_new_simple_substep(&root_step_id, session.step_store_mut());

        let test_action_id = session.action_store_mut().insert_new(
            |id| Ok(TestAction::new_with_id(id, true).boxed()))
            .unwrap();
        session.set_action_for_step(test_action_id, None).unwrap();

        // Record the step each advance stops on until the session finishes.
        let mut steps_executed: Vec<StepId> = vec![];
        loop {
            match session.advance(None) {
                Ok(AdvanceBlockedOn::ActionStartWith(_, _)) => (),
                Ok(AdvanceBlockedOn::FinishedAdvancing) => break,
                Ok(advance_result) => panic!("Unexpected advance result: {:?}", advance_result),
                Err(err) => panic!("unexpected error trying to advance: {:?}", err),
            }
            steps_executed.push(session.current_step().unwrap().clone());
        }

        assert_eq!(steps_executed, vec![substep1, substep2, substep3]);
    }

    #[test]
    fn specific_generic_actions() {
        let (mut session, root_step_id) = Session::test_new();
        let var_id = session.test_new_stringvar();

        let substep1 = session.step_store_mut().insert_new(|id| {
            Ok(Step::new(id, None, vec![var_id.clone()]))
        })
        .unwrap();
        push_substep(&root_step_id, substep1.clone(), session.step_store_mut());

        let substep2 = session.step_store_mut().insert_new(
            |id| Ok(Step::new(id, Some(vec![var_id.clone()]), vec![var_id.clone()])))
            .unwrap();
        push_substep(&root_step_id, substep2.clone(), session.step_store_mut());

        let mut statedata_exec = StateData::new();
        let var = session.var_store().get(&var_id).unwrap();
        statedata_exec.insert(var, StringValue::try_new("hi").unwrap().boxed()).unwrap();

        // Presumably blocks twice before finishing with `statedata_exec`; the
        // assertions below rely on that sequence.
        let set_action_id = session.action_store_mut().insert_new(|id| {
            Ok(SetDataAction::new(id, statedata_exec, 2).boxed())
        }).unwrap();

        let test_action_id = session.action_store_mut().insert_new(|id| {
            Ok(TestAction::new_with_id(id, true).boxed())
        })
        .unwrap();

        // substep1 gets its own action; everything else uses the catch-all.
        session.set_action_for_step(set_action_id, Some(&substep1)).unwrap();
        session.set_action_for_step(test_action_id, None).unwrap();

        if let AdvanceBlockedOn::ActionStartWith(_, _) = session.advance(None).unwrap() {
            assert_eq!(*session.current_step().unwrap(), substep1.clone());
        } else {
            panic!("did not advance");
        }

        if let AdvanceBlockedOn::ActionStartWith(_, _) = session.advance(None).unwrap() {
            assert!(!session.state_data.contains(&var_id));
        } else {
            panic!("did not advance");
        }

        if let AdvanceBlockedOn::ActionStartWith(_, _) = session.advance(None).unwrap() {
            assert_eq!(*session.current_step().unwrap(), substep2.clone());
            assert!(session.state_data.contains(&var_id));
        } else {
            panic!("did not advance");
        }

        assert_eq!(
            session.advance(None).unwrap(),
            AdvanceBlockedOn::FinishedAdvancing);
    }

    #[test]
    fn auto_advance() {
        let (mut session, root_step_id) = Session::test_new();
        let test_action_id = session.action_store_mut().insert_new(|id| {
            Ok(TestAction::new_with_id(id, false).boxed())
        })
        .unwrap();

        let _substep1 = add_new_simple_substep(&root_step_id, session.step_store_mut());
        let _substep2 = add_new_simple_substep(&root_step_id, session.step_store_mut());
        let _substep3 = add_new_simple_substep(&root_step_id, session.step_store_mut());

        session.set_action_for_step(test_action_id, None).unwrap();

        // With a non-blocking catch-all action, a single advance runs to the end.
        let advance = session.advance(None);
        assert_eq!(advance, Ok(AdvanceBlockedOn::FinishedAdvancing));
    }

    #[test]
    fn duplicate_action_registration() {
        let (mut session, root_step_id) = Session::test_new();
        let substep = add_new_simple_substep(&root_step_id, session.step_store_mut());

        let first_action_id = session.action_store_mut()
            .insert_new(|id| Ok(TestAction::new_with_id(id, true).boxed()))
            .unwrap();
        let second_action_id = session.action_store_mut()
            .insert_new(|id| Ok(TestAction::new_with_id(id, true).boxed()))
            .unwrap();

        // A step-specific action can only be registered once per step.
        session.set_action_for_step(first_action_id.clone(), Some(&substep)).unwrap();
        assert_eq!(
            session.set_action_for_step(second_action_id.clone(), Some(&substep)),
            Err(Error::StepId(IdError::IdAlreadyExists(substep.clone()))));

        // The same holds for the catch-all action.
        session.set_action_for_step(first_action_id, None).unwrap();
        assert_eq!(
            session.set_action_for_step(second_action_id, None),
            Err(Error::StepId(IdError::IdAlreadyExists(session.step_id_all.clone()))));
    }

    #[test]
    fn advance_blocked_on_eq() {
        let abo_finish = AdvanceBlockedOn::FinishedAdvancing;
        assert_eq!(abo_finish, abo_finish);

        let abo_cannot_fulfill = AdvanceBlockedOn::ActionCannotFulfill;
        assert_ne!(abo_finish, abo_cannot_fulfill);

        let action_id = test_id!(ActionId);
        let abo_start_true = AdvanceBlockedOn::ActionStartWith(action_id.clone(), BoolValue::new(true).boxed());
        let abo_start_false = AdvanceBlockedOn::ActionStartWith(action_id, BoolValue::new(false).boxed());
        assert_eq!(abo_start_false, abo_start_false);
        assert_ne!(abo_start_true, abo_start_false);
        assert_ne!(abo_start_false, abo_finish);
    }
}
603