1use std::collections::HashMap;
10use std::sync::atomic::Ordering;
11use std::sync::{Arc, Mutex};
12use std::time::Instant;
13
14use ainl_graph_extractor::GraphExtractorTask;
15use ainl_memory::{
16 AinlMemoryNode, AinlNodeType, GraphMemory, GraphStore, GraphValidationReport, RuntimeStateNode,
17};
18use uuid::Uuid;
19
20use crate::adapters::GraphPatchAdapter;
21use crate::engine::{
22 AinlRuntimeError, MemoryContext, PatchDispatchContext, PatchDispatchResult, PatchSkipReason,
23 TurnInput, TurnOutcome, TurnPhase, TurnResult, TurnStatus, TurnWarning, EMIT_TO_EDGE,
24};
25use super::{
26 compile_persona_from_nodes, emit_target_name, normalize_tools_for_episode,
27 persona_snapshot_if_evolved, procedural_label, record_turn_episode,
28 try_export_graph_json_armaraos,
29};
30
31async fn graph_spawn<T, F>(arc: Arc<Mutex<GraphMemory>>, f: F) -> Result<T, AinlRuntimeError>
32where
33 T: Send + 'static,
34 F: FnOnce(&GraphMemory) -> Result<T, String> + Send + 'static,
35{
36 tokio::task::spawn_blocking(move || {
37 let guard = arc.lock().expect("graph mutex poisoned");
38 f(&guard)
39 })
40 .await
41 .map_err(|e| AinlRuntimeError::AsyncJoinError(e.to_string()))?
42 .map_err(AinlRuntimeError::from)
43}
44
45impl super::AinlRuntime {
46 pub async fn run_turn_async(&mut self, input: TurnInput) -> Result<TurnOutcome, AinlRuntimeError> {
53 let depth = self.current_depth.fetch_add(1, Ordering::SeqCst);
54 let cd = Arc::clone(&self.current_depth);
55 let _depth_guard = scopeguard::guard((), move |()| {
56 cd.fetch_sub(1, Ordering::SeqCst);
57 });
58
59 if depth >= self.config.max_delegation_depth {
60 return Err(AinlRuntimeError::DelegationDepthExceeded {
61 depth,
62 max: self.config.max_delegation_depth,
63 });
64 }
65
66 if let Some(ref hooks_async) = self.hooks_async {
67 hooks_async.on_turn_start(&input).await;
68 }
69
70 if !self.config.enable_graph_memory {
71 let memory_context = MemoryContext::default();
72 let result = TurnResult {
73 memory_context,
74 status: TurnStatus::GraphMemoryDisabled,
75 ..Default::default()
76 };
77 let outcome = TurnOutcome::Complete(result);
78 self.hooks.on_turn_complete(&outcome);
79 if let Some(ref hooks_async) = self.hooks_async {
80 hooks_async.on_turn_complete(&outcome).await;
81 }
82 return Ok(outcome);
83 }
84
85 if self.config.agent_id.is_empty() {
86 return Err(AinlRuntimeError::Message(
87 "RuntimeConfig.agent_id must be set for run_turn".into(),
88 ));
89 }
90
91 let span = tracing::info_span!(
92 "ainl_runtime.run_turn_async",
93 agent_id = %self.config.agent_id,
94 turn = self.turn_count,
95 depth = input.depth,
96 );
97 let _span_enter = span.enter();
98
99 let arc = self.memory.shared_arc();
100 let agent_id = self.config.agent_id.clone();
101
102 let validation: GraphValidationReport = graph_spawn(Arc::clone(&arc), {
103 let agent_id = agent_id.clone();
104 move |m| m.sqlite_store().validate_graph(&agent_id)
105 })
106 .await?;
107
108 if !validation.is_valid {
109 let mut msg = String::from("graph validation failed before turn");
110 for d in &validation.dangling_edge_details {
111 msg.push_str(&format!(
112 "; {} -> {} [{}]",
113 d.source_id, d.target_id, d.edge_type
114 ));
115 }
116 return Err(AinlRuntimeError::Message(msg));
117 }
118
119 self.hooks
120 .on_artifact_loaded(&self.config.agent_id, validation.node_count);
121
122 let mut turn_warnings: Vec<TurnWarning> = Vec::new();
123
124 let t_persona = Instant::now();
125 let persona_prompt_contribution = if let Some(cached) = &self.persona_cache {
126 Some(cached.clone())
127 } else {
128 let nodes = graph_spawn(Arc::clone(&arc), {
129 let agent_id = agent_id.clone();
130 move |m| {
131 m.sqlite_store()
132 .query(&agent_id)
133 .persona_nodes()
134 }
135 })
136 .await?;
137 let compiled = compile_persona_from_nodes(&nodes).map_err(AinlRuntimeError::from)?;
138 self.persona_cache = compiled.clone();
139 compiled
140 };
141 self.hooks
142 .on_persona_compiled(persona_prompt_contribution.as_deref());
143 tracing::debug!(
144 target: "ainl_runtime",
145 duration_ms = t_persona.elapsed().as_millis() as u64,
146 has_contribution = persona_prompt_contribution.is_some(),
147 "persona_phase_async"
148 );
149
150 let t_memory = Instant::now();
151 let (recent_episodes, all_semantic, active_patches) = graph_spawn(Arc::clone(&arc), {
152 let agent_id = agent_id.clone();
153 move |m| {
154 let store = m.sqlite_store();
155 let q = store.query(&agent_id);
156 let recent_episodes = q.recent_episodes(5)?;
157 let all_semantic = q.semantic_nodes()?;
158 let active_patches = q.active_patches()?;
159 Ok((recent_episodes, all_semantic, active_patches))
160 }
161 })
162 .await?;
163
164 let relevant_semantic =
165 self.relevant_semantic_nodes(input.user_message.as_str(), all_semantic, 10);
166 let memory_context = MemoryContext {
167 recent_episodes,
168 relevant_semantic,
169 active_patches,
170 persona_snapshot: persona_snapshot_if_evolved(&self.extractor),
171 compiled_at: chrono::Utc::now(),
172 };
173
174 self.hooks.on_memory_context_ready(&memory_context);
175 tracing::debug!(
176 target: "ainl_runtime",
177 duration_ms = t_memory.elapsed().as_millis() as u64,
178 episode_count = memory_context.recent_episodes.len(),
179 semantic_count = memory_context.relevant_semantic.len(),
180 patch_count = memory_context.active_patches.len(),
181 "memory_context_async"
182 );
183
184 let t_patches = Instant::now();
185 let patch_dispatch_results = if self.config.enable_graph_memory {
186 self.dispatch_patches_collect_async(&memory_context.active_patches, &input.frame, &arc, &mut turn_warnings)
187 .await?
188 } else {
189 Vec::new()
190 };
191 for r in &patch_dispatch_results {
192 tracing::debug!(
193 target: "ainl_runtime",
194 label = %r.label,
195 dispatched = r.dispatched,
196 fitness_before = r.fitness_before,
197 fitness_after = r.fitness_after,
198 "patch_dispatch_async"
199 );
200 }
201 tracing::debug!(
202 target: "ainl_runtime",
203 duration_ms = t_patches.elapsed().as_millis() as u64,
204 "patch_dispatch_phase_async"
205 );
206
207 let dispatched_count = patch_dispatch_results
208 .iter()
209 .filter(|r| r.dispatched)
210 .count() as u32;
211 if dispatched_count >= self.config.max_steps {
212 let result = TurnResult {
213 patch_dispatch_results,
214 memory_context,
215 persona_prompt_contribution,
216 steps_executed: dispatched_count,
217 status: TurnStatus::StepLimitExceeded {
218 steps_executed: dispatched_count,
219 },
220 ..Default::default()
221 };
222 let outcome = TurnOutcome::Complete(result);
223 self.hooks.on_turn_complete(&outcome);
224 if let Some(ref hooks_async) = self.hooks_async {
225 hooks_async.on_turn_complete(&outcome).await;
226 }
227 return Ok(outcome);
228 }
229
230 let t_episode = Instant::now();
231 let tools_canonical = normalize_tools_for_episode(&input.tools_invoked);
232 let tools_for_episode = tools_canonical.clone();
233 let input_clone = input.clone();
234 let episode_id = match graph_spawn(Arc::clone(&arc), {
235 let agent_id = agent_id.clone();
236 move |m| record_turn_episode(m, &agent_id, &input_clone, &tools_for_episode)
237 })
238 .await
239 {
240 Ok(id) => id,
241 Err(e) => {
242 let e = e.message_str().unwrap_or("episode write").to_string();
243 tracing::warn!(
244 phase = ?TurnPhase::EpisodeWrite,
245 error = %e,
246 "non-fatal turn write failed — continuing"
247 );
248 turn_warnings.push(TurnWarning {
249 phase: TurnPhase::EpisodeWrite,
250 error: e,
251 });
252 Uuid::nil()
253 }
254 };
255 self.hooks.on_episode_recorded(episode_id);
256 tracing::debug!(
257 target: "ainl_runtime",
258 duration_ms = t_episode.elapsed().as_millis() as u64,
259 episode_id = %episode_id,
260 "episode_record_async"
261 );
262
263 if !episode_id.is_nil() {
264 for &tid in &input.emit_targets {
265 let eid = episode_id;
266 if let Err(e) = graph_spawn(Arc::clone(&arc), move |m| {
267 m.sqlite_store()
268 .insert_graph_edge_checked(eid, tid, EMIT_TO_EDGE)
269 })
270 .await
271 {
272 let e = e.message_str().unwrap_or("edge").to_string();
273 tracing::warn!(
274 phase = ?TurnPhase::EpisodeWrite,
275 error = %e,
276 "non-fatal turn write failed — continuing"
277 );
278 turn_warnings.push(TurnWarning {
279 phase: TurnPhase::EpisodeWrite,
280 error: e,
281 });
282 }
283 }
284 }
285
286 let emit_payload = serde_json::json!({
287 "episode_id": episode_id.to_string(),
288 "user_message": input.user_message,
289 "tools_invoked": tools_canonical,
290 "persona_contribution": persona_prompt_contribution,
291 "turn_count": self.turn_count.wrapping_add(1),
292 });
293 let emit_neighbors = graph_spawn(Arc::clone(&arc), {
294 let agent_id = agent_id.clone();
295 let eid = episode_id;
296 move |m| {
297 let store = m.sqlite_store();
298 store.query(&agent_id).neighbors(eid, EMIT_TO_EDGE)
299 }
300 })
301 .await;
302 match emit_neighbors {
303 Ok(neighbors) => {
304 for n in neighbors {
305 let target = emit_target_name(&n);
306 self.hooks.on_emit(&target, &emit_payload);
307 }
308 }
309 Err(e) => {
310 let e = e.message_str().unwrap_or("emit").to_string();
311 tracing::warn!(
312 phase = ?TurnPhase::EpisodeWrite,
313 error = %e,
314 "non-fatal turn write failed — continuing"
315 );
316 turn_warnings.push(TurnWarning {
317 phase: TurnPhase::EpisodeWrite,
318 error: format!("emit_routing: {e}"),
319 });
320 }
321 }
322
323 self.turn_count = self.turn_count.wrapping_add(1);
324
325 let should_extract = self.config.extraction_interval > 0
326 && self.turn_count.saturating_sub(self.last_extraction_at_turn)
327 >= self.config.extraction_interval as u64;
328
329 let t_extract = Instant::now();
330 let (extraction_report, _extraction_failed) = if should_extract {
331 let force_fail = std::mem::take(&mut self.test_force_extraction_failure);
332
333 let res = if force_fail {
334 let e = "test_forced".to_string();
335 tracing::warn!(
336 phase = ?TurnPhase::ExtractionPass,
337 error = %e,
338 "non-fatal turn write failed — continuing"
339 );
340 turn_warnings.push(TurnWarning {
341 phase: TurnPhase::ExtractionPass,
342 error: e,
343 });
344 tracing::debug!(
345 target: "ainl_runtime",
346 duration_ms = t_extract.elapsed().as_millis() as u64,
347 signals_ingested = 0u64,
348 skipped = false,
349 "extraction_pass_async"
350 );
351 (None, true)
352 } else {
353 let mem = Arc::clone(&arc);
354 let placeholder = GraphExtractorTask::new(&agent_id);
355 let mut task = std::mem::replace(&mut self.extractor, placeholder);
356 let (task_back, report) = tokio::task::spawn_blocking(move || {
357 let g = mem.lock().expect("graph mutex poisoned");
358 let report = task.run_pass(g.sqlite_store());
359 (task, report)
360 })
361 .await
362 .map_err(|e| AinlRuntimeError::AsyncJoinError(e.to_string()))?;
363 self.extractor = task_back;
364
365 if let Some(ref e) = report.extract_error {
366 tracing::warn!(
367 phase = ?TurnPhase::ExtractionPass,
368 error = %e,
369 "non-fatal turn write failed — continuing"
370 );
371 turn_warnings.push(TurnWarning {
372 phase: TurnPhase::ExtractionPass,
373 error: e.clone(),
374 });
375 }
376 if let Some(ref e) = report.pattern_error {
377 tracing::warn!(
378 phase = ?TurnPhase::PatternPersistence,
379 error = %e,
380 "non-fatal turn write failed — continuing"
381 );
382 turn_warnings.push(TurnWarning {
383 phase: TurnPhase::PatternPersistence,
384 error: e.clone(),
385 });
386 }
387 if let Some(ref e) = report.persona_error {
388 tracing::warn!(
389 phase = ?TurnPhase::PersonaEvolution,
390 error = %e,
391 "non-fatal turn write failed — continuing"
392 );
393 turn_warnings.push(TurnWarning {
394 phase: TurnPhase::PersonaEvolution,
395 error: e.clone(),
396 });
397 }
398 let extraction_failed = report.has_errors();
399 if !extraction_failed {
400 tracing::info!(
401 agent_id = %report.agent_id,
402 signals_extracted = report.signals_extracted,
403 signals_applied = report.signals_applied,
404 semantic_nodes_updated = report.semantic_nodes_updated,
405 "ainl-graph-extractor pass completed (scheduled, async)"
406 );
407 }
408 self.hooks.on_extraction_complete(&report);
409 self.persona_cache = None;
410 tracing::debug!(
411 target: "ainl_runtime",
412 duration_ms = t_extract.elapsed().as_millis() as u64,
413 signals_ingested = report.signals_extracted as u64,
414 skipped = false,
415 "extraction_pass_async"
416 );
417 (Some(report), extraction_failed)
418 };
419 self.last_extraction_at_turn = self.turn_count;
420 res
421 } else {
422 tracing::debug!(
423 target: "ainl_runtime",
424 duration_ms = t_extract.elapsed().as_millis() as u64,
425 signals_ingested = 0u64,
426 skipped = true,
427 "extraction_pass_async"
428 );
429 (None, false)
430 };
431
432 if let Err(e) = graph_spawn(Arc::clone(&arc), {
433 let agent_id = agent_id.clone();
434 move |m| try_export_graph_json_armaraos(m.sqlite_store(), &agent_id)
435 })
436 .await
437 {
438 let e = e.message_str().unwrap_or("export").to_string();
439 tracing::warn!(
440 phase = ?TurnPhase::ExportRefresh,
441 error = %e,
442 "non-fatal turn write failed — continuing"
443 );
444 turn_warnings.push(TurnWarning {
445 phase: TurnPhase::ExportRefresh,
446 error: e,
447 });
448 }
449
450 if !self.config.agent_id.is_empty() {
451 let state = RuntimeStateNode {
452 agent_id: self.config.agent_id.clone(),
453 turn_count: self.turn_count,
454 last_extraction_at_turn: self.last_extraction_at_turn,
455 persona_snapshot_json: self
456 .persona_cache
457 .as_ref()
458 .and_then(|p| serde_json::to_string(p).ok()),
459 updated_at: chrono::Utc::now().timestamp(),
460 };
461 let force_fail = std::mem::take(&mut self.test_force_runtime_state_write_failure);
462 let write_res: Result<(), AinlRuntimeError> = if force_fail {
463 Err(AinlRuntimeError::Message(
464 "injected runtime state write failure".into(),
465 ))
466 } else {
467 graph_spawn(Arc::clone(&arc), move |m| m.write_runtime_state(&state)).await
468 };
469 if let Err(e) = write_res {
470 let e = e.to_string();
471 tracing::warn!(
472 phase = ?TurnPhase::RuntimeStatePersist,
473 error = %e,
474 "failed to persist runtime state — cadence will reset on next restart"
475 );
476 turn_warnings.push(TurnWarning {
477 phase: TurnPhase::RuntimeStatePersist,
478 error: e,
479 });
480 }
481 }
482
483 let result = TurnResult {
484 episode_id,
485 persona_prompt_contribution,
486 memory_context,
487 extraction_report,
488 steps_executed: dispatched_count,
489 patch_dispatch_results,
490 status: TurnStatus::Ok,
491 };
492
493 let outcome = if turn_warnings.is_empty() {
494 TurnOutcome::Complete(result)
495 } else {
496 TurnOutcome::PartialSuccess {
497 result,
498 warnings: turn_warnings,
499 }
500 };
501
502 self.hooks.on_turn_complete(&outcome);
503 if let Some(ref hooks_async) = self.hooks_async {
504 hooks_async.on_turn_complete(&outcome).await;
505 }
506 Ok(outcome)
507 }
508
509 async fn dispatch_patches_collect_async(
510 &mut self,
511 patches: &[AinlMemoryNode],
512 frame: &HashMap<String, serde_json::Value>,
513 arc: &Arc<Mutex<GraphMemory>>,
514 turn_warnings: &mut Vec<TurnWarning>,
515 ) -> Result<Vec<PatchDispatchResult>, AinlRuntimeError> {
516 let mut out = Vec::new();
517 for node in patches {
518 let res = self
519 .dispatch_one_patch_async(node, frame, Arc::clone(arc))
520 .await?;
521 if let Some(PatchSkipReason::PersistFailed(ref e)) = res.skip_reason {
522 tracing::warn!(
523 phase = ?TurnPhase::FitnessWriteBack,
524 error = %e,
525 "non-fatal turn write failed — continuing"
526 );
527 turn_warnings.push(TurnWarning {
528 phase: TurnPhase::FitnessWriteBack,
529 error: format!("{}: {}", res.label, e),
530 });
531 }
532 out.push(res);
533 }
534 Ok(out)
535 }
536
537 async fn dispatch_one_patch_async(
538 &mut self,
539 node: &AinlMemoryNode,
540 frame: &HashMap<String, serde_json::Value>,
541 arc: Arc<Mutex<GraphMemory>>,
542 ) -> Result<PatchDispatchResult, AinlRuntimeError> {
543 let label_default = String::new();
544 let (label_src, pv, retired, reads, fitness_opt) = match &node.node_type {
545 AinlNodeType::Procedural { procedural } => (
546 procedural_label(procedural),
547 procedural.patch_version,
548 procedural.retired,
549 procedural.declared_reads.clone(),
550 procedural.fitness,
551 ),
552 _ => {
553 return Ok(PatchDispatchResult {
554 label: label_default,
555 patch_version: 0,
556 fitness_before: 0.0,
557 fitness_after: 0.0,
558 dispatched: false,
559 skip_reason: Some(PatchSkipReason::NotProcedural),
560 adapter_output: None,
561 adapter_name: None,
562 });
563 }
564 };
565
566 if pv == 0 {
567 return Ok(PatchDispatchResult {
568 label: label_src,
569 patch_version: pv,
570 fitness_before: fitness_opt.unwrap_or(0.5),
571 fitness_after: fitness_opt.unwrap_or(0.5),
572 dispatched: false,
573 skip_reason: Some(PatchSkipReason::ZeroVersion),
574 adapter_output: None,
575 adapter_name: None,
576 });
577 }
578 if retired {
579 return Ok(PatchDispatchResult {
580 label: label_src.clone(),
581 patch_version: pv,
582 fitness_before: fitness_opt.unwrap_or(0.5),
583 fitness_after: fitness_opt.unwrap_or(0.5),
584 dispatched: false,
585 skip_reason: Some(PatchSkipReason::Retired),
586 adapter_output: None,
587 adapter_name: None,
588 });
589 }
590 for key in &reads {
591 if !frame.contains_key(key) {
592 return Ok(PatchDispatchResult {
593 label: label_src.clone(),
594 patch_version: pv,
595 fitness_before: fitness_opt.unwrap_or(0.5),
596 fitness_after: fitness_opt.unwrap_or(0.5),
597 dispatched: false,
598 skip_reason: Some(PatchSkipReason::MissingDeclaredRead(key.clone())),
599 adapter_output: None,
600 adapter_name: None,
601 });
602 }
603 }
604
605 let patch_label = label_src.clone();
606 let adapter_key = patch_label.as_str();
607 let ctx = PatchDispatchContext {
608 patch_label: adapter_key,
609 node,
610 frame,
611 };
612 let (adapter_output, adapter_name) = if let Some(adapter) = self
613 .adapter_registry
614 .get(adapter_key)
615 .or_else(|| self.adapter_registry.get(GraphPatchAdapter::NAME))
616 {
617 let aname = adapter.name().to_string();
618 match adapter.execute_patch(&ctx) {
619 Ok(output) => {
620 tracing::debug!(
621 label = %patch_label,
622 adapter = %aname,
623 "adapter executed patch (async)"
624 );
625 (Some(output), Some(aname))
626 }
627 Err(e) => {
628 tracing::warn!(
629 label = %patch_label,
630 adapter = %aname,
631 error = %e,
632 "adapter execution failed — continuing as metadata dispatch"
633 );
634 (None, Some(aname))
635 }
636 }
637 } else {
638 (None, None)
639 };
640
641 let fitness_before = fitness_opt.unwrap_or(0.5);
642 let fitness_after = 0.2_f32 * 1.0 + 0.8 * fitness_before;
643
644 let nid = node.id;
645 let updated = match graph_spawn(Arc::clone(&arc), move |m| {
646 let store = m.sqlite_store();
647 store.read_node(nid)
648 })
649 .await?
650 {
651 Some(mut n) => {
652 if let AinlNodeType::Procedural { ref mut procedural } = n.node_type {
653 procedural.fitness = Some(fitness_after);
654 }
655 n
656 }
657 None => {
658 return Ok(PatchDispatchResult {
659 label: label_src,
660 patch_version: pv,
661 fitness_before,
662 fitness_after: fitness_before,
663 dispatched: false,
664 skip_reason: Some(PatchSkipReason::MissingDeclaredRead("node_row".into())),
665 adapter_output,
666 adapter_name,
667 });
668 }
669 };
670
671 if self.test_force_fitness_write_failure {
672 self.test_force_fitness_write_failure = false;
673 let e = "injected fitness write failure".to_string();
674 return Ok(PatchDispatchResult {
675 label: label_src.clone(),
676 patch_version: pv,
677 fitness_before,
678 fitness_after: fitness_before,
679 dispatched: false,
680 skip_reason: Some(PatchSkipReason::PersistFailed(e)),
681 adapter_output,
682 adapter_name,
683 });
684 }
685
686 let updated_clone = updated.clone();
687 if let Err(e) = graph_spawn(arc, move |m| m.write_node(&updated_clone)).await {
688 return Ok(PatchDispatchResult {
689 label: label_src.clone(),
690 patch_version: pv,
691 fitness_before,
692 fitness_after: fitness_before,
693 dispatched: false,
694 skip_reason: Some(PatchSkipReason::PersistFailed(
695 e.message_str().unwrap_or("write").to_string(),
696 )),
697 adapter_output,
698 adapter_name,
699 });
700 }
701
702 self.hooks
703 .on_patch_dispatched(label_src.as_str(), fitness_after);
704 if let Some(ref hooks_async) = self.hooks_async {
705 let hook_ctx = PatchDispatchContext {
706 patch_label: adapter_key,
707 node,
708 frame,
709 };
710 let _ = hooks_async.on_patch_dispatched(&hook_ctx).await;
711 }
712
713 Ok(PatchDispatchResult {
714 label: label_src,
715 patch_version: pv,
716 fitness_before,
717 fitness_after,
718 dispatched: true,
719 skip_reason: None,
720 adapter_output,
721 adapter_name,
722 })
723 }
724}