1use std::collections::HashMap;
10use std::sync::atomic::Ordering;
11use std::sync::{Arc, Mutex};
12use std::time::Instant;
13
14use ainl_graph_extractor::GraphExtractorTask;
15use ainl_memory::{
16 AinlMemoryNode, AinlNodeType, GraphMemory, GraphStore, GraphValidationReport, RuntimeStateNode,
17};
18use uuid::Uuid;
19
20use super::{
21 compile_persona_from_nodes, emit_target_name, normalize_tools_for_episode,
22 persona_snapshot_if_evolved, procedural_label, record_turn_episode,
23 try_export_graph_json_armaraos,
24};
25use crate::adapters::GraphPatchAdapter;
26use crate::engine::{
27 AinlRuntimeError, MemoryContext, PatchDispatchContext, PatchDispatchResult, PatchSkipReason,
28 TurnInput, TurnOutcome, TurnPhase, TurnResult, TurnStatus, TurnWarning, EMIT_TO_EDGE,
29};
30
31async fn graph_spawn<T, F>(arc: Arc<Mutex<GraphMemory>>, f: F) -> Result<T, AinlRuntimeError>
32where
33 T: Send + 'static,
34 F: FnOnce(&GraphMemory) -> Result<T, String> + Send + 'static,
35{
36 tokio::task::spawn_blocking(move || {
37 let guard = arc.lock().expect("graph mutex poisoned");
38 f(&guard)
39 })
40 .await
41 .map_err(|e| AinlRuntimeError::AsyncJoinError(e.to_string()))?
42 .map_err(AinlRuntimeError::from)
43}
44
45impl super::AinlRuntime {
46 pub async fn run_turn_async(
53 &mut self,
54 input: TurnInput,
55 ) -> Result<TurnOutcome, AinlRuntimeError> {
56 let depth = self.current_depth.fetch_add(1, Ordering::SeqCst);
57 let cd = Arc::clone(&self.current_depth);
58 let _depth_guard = scopeguard::guard((), move |()| {
59 cd.fetch_sub(1, Ordering::SeqCst);
60 });
61
62 if depth >= self.config.max_delegation_depth {
63 return Err(AinlRuntimeError::DelegationDepthExceeded {
64 depth,
65 max: self.config.max_delegation_depth,
66 });
67 }
68
69 if let Some(ref hooks_async) = self.hooks_async {
70 hooks_async.on_turn_start(&input).await;
71 }
72
73 if !self.config.enable_graph_memory {
74 let memory_context = MemoryContext::default();
75 let result = TurnResult {
76 memory_context,
77 status: TurnStatus::GraphMemoryDisabled,
78 ..Default::default()
79 };
80 let outcome = TurnOutcome::Complete(result);
81 self.hooks.on_turn_complete(&outcome);
82 if let Some(ref hooks_async) = self.hooks_async {
83 hooks_async.on_turn_complete(&outcome).await;
84 }
85 return Ok(outcome);
86 }
87
88 if self.config.agent_id.is_empty() {
89 return Err(AinlRuntimeError::Message(
90 "RuntimeConfig.agent_id must be set for run_turn".into(),
91 ));
92 }
93
94 let span = tracing::info_span!(
95 "ainl_runtime.run_turn_async",
96 agent_id = %self.config.agent_id,
97 turn = self.turn_count,
98 depth = input.depth,
99 );
100 let _span_enter = span.enter();
101
102 let arc = self.memory.shared_arc();
103 let agent_id = self.config.agent_id.clone();
104
105 let validation: GraphValidationReport = graph_spawn(Arc::clone(&arc), {
106 let agent_id = agent_id.clone();
107 move |m| m.sqlite_store().validate_graph(&agent_id)
108 })
109 .await?;
110
111 if !validation.is_valid {
112 let mut msg = String::from("graph validation failed before turn");
113 for d in &validation.dangling_edge_details {
114 msg.push_str(&format!(
115 "; {} -> {} [{}]",
116 d.source_id, d.target_id, d.edge_type
117 ));
118 }
119 return Err(AinlRuntimeError::Message(msg));
120 }
121
122 self.hooks
123 .on_artifact_loaded(&self.config.agent_id, validation.node_count);
124
125 let mut turn_warnings: Vec<TurnWarning> = Vec::new();
126
127 let t_persona = Instant::now();
128 let persona_prompt_contribution = if let Some(cached) = &self.persona_cache {
129 Some(cached.clone())
130 } else {
131 let nodes = graph_spawn(Arc::clone(&arc), {
132 let agent_id = agent_id.clone();
133 move |m| m.sqlite_store().query(&agent_id).persona_nodes()
134 })
135 .await?;
136 let compiled = compile_persona_from_nodes(&nodes).map_err(AinlRuntimeError::from)?;
137 self.persona_cache = compiled.clone();
138 compiled
139 };
140 self.hooks
141 .on_persona_compiled(persona_prompt_contribution.as_deref());
142 tracing::debug!(
143 target: "ainl_runtime",
144 duration_ms = t_persona.elapsed().as_millis() as u64,
145 has_contribution = persona_prompt_contribution.is_some(),
146 "persona_phase_async"
147 );
148
149 let t_memory = Instant::now();
150 let (recent_episodes, all_semantic, active_patches) = graph_spawn(Arc::clone(&arc), {
151 let agent_id = agent_id.clone();
152 move |m| {
153 let store = m.sqlite_store();
154 let q = store.query(&agent_id);
155 let recent_episodes = q.recent_episodes(5)?;
156 let all_semantic = q.semantic_nodes()?;
157 let active_patches = q.active_patches()?;
158 Ok((recent_episodes, all_semantic, active_patches))
159 }
160 })
161 .await?;
162
163 let relevant_semantic =
164 self.relevant_semantic_nodes(input.user_message.as_str(), all_semantic, 10);
165 let memory_context = MemoryContext {
166 recent_episodes,
167 relevant_semantic,
168 active_patches,
169 persona_snapshot: persona_snapshot_if_evolved(&self.extractor),
170 compiled_at: chrono::Utc::now(),
171 };
172
173 self.hooks.on_memory_context_ready(&memory_context);
174 tracing::debug!(
175 target: "ainl_runtime",
176 duration_ms = t_memory.elapsed().as_millis() as u64,
177 episode_count = memory_context.recent_episodes.len(),
178 semantic_count = memory_context.relevant_semantic.len(),
179 patch_count = memory_context.active_patches.len(),
180 "memory_context_async"
181 );
182
183 let t_patches = Instant::now();
184 let patch_dispatch_results = if self.config.enable_graph_memory {
185 self.dispatch_patches_collect_async(
186 &memory_context.active_patches,
187 &input.frame,
188 &arc,
189 &mut turn_warnings,
190 )
191 .await?
192 } else {
193 Vec::new()
194 };
195 for r in &patch_dispatch_results {
196 tracing::debug!(
197 target: "ainl_runtime",
198 label = %r.label,
199 dispatched = r.dispatched,
200 fitness_before = r.fitness_before,
201 fitness_after = r.fitness_after,
202 "patch_dispatch_async"
203 );
204 }
205 tracing::debug!(
206 target: "ainl_runtime",
207 duration_ms = t_patches.elapsed().as_millis() as u64,
208 "patch_dispatch_phase_async"
209 );
210
211 let dispatched_count = patch_dispatch_results
212 .iter()
213 .filter(|r| r.dispatched)
214 .count() as u32;
215 if dispatched_count >= self.config.max_steps {
216 let result = TurnResult {
217 patch_dispatch_results,
218 memory_context,
219 persona_prompt_contribution,
220 steps_executed: dispatched_count,
221 status: TurnStatus::StepLimitExceeded {
222 steps_executed: dispatched_count,
223 },
224 ..Default::default()
225 };
226 let outcome = TurnOutcome::Complete(result);
227 self.hooks.on_turn_complete(&outcome);
228 if let Some(ref hooks_async) = self.hooks_async {
229 hooks_async.on_turn_complete(&outcome).await;
230 }
231 return Ok(outcome);
232 }
233
234 let t_episode = Instant::now();
235 let tools_canonical = normalize_tools_for_episode(&input.tools_invoked);
236 let tools_for_episode = tools_canonical.clone();
237 let input_clone = input.clone();
238 let episode_id = match graph_spawn(Arc::clone(&arc), {
239 let agent_id = agent_id.clone();
240 move |m| record_turn_episode(m, &agent_id, &input_clone, &tools_for_episode)
241 })
242 .await
243 {
244 Ok(id) => id,
245 Err(e) => {
246 let e = e.message_str().unwrap_or("episode write").to_string();
247 tracing::warn!(
248 phase = ?TurnPhase::EpisodeWrite,
249 error = %e,
250 "non-fatal turn write failed — continuing"
251 );
252 turn_warnings.push(TurnWarning {
253 phase: TurnPhase::EpisodeWrite,
254 error: e,
255 });
256 Uuid::nil()
257 }
258 };
259 self.hooks.on_episode_recorded(episode_id);
260 tracing::debug!(
261 target: "ainl_runtime",
262 duration_ms = t_episode.elapsed().as_millis() as u64,
263 episode_id = %episode_id,
264 "episode_record_async"
265 );
266
267 if !episode_id.is_nil() {
268 for &tid in &input.emit_targets {
269 let eid = episode_id;
270 if let Err(e) = graph_spawn(Arc::clone(&arc), move |m| {
271 m.sqlite_store()
272 .insert_graph_edge_checked(eid, tid, EMIT_TO_EDGE)
273 })
274 .await
275 {
276 let e = e.message_str().unwrap_or("edge").to_string();
277 tracing::warn!(
278 phase = ?TurnPhase::EpisodeWrite,
279 error = %e,
280 "non-fatal turn write failed — continuing"
281 );
282 turn_warnings.push(TurnWarning {
283 phase: TurnPhase::EpisodeWrite,
284 error: e,
285 });
286 }
287 }
288 }
289
290 let emit_payload = serde_json::json!({
291 "episode_id": episode_id.to_string(),
292 "user_message": input.user_message,
293 "tools_invoked": tools_canonical,
294 "persona_contribution": persona_prompt_contribution,
295 "turn_count": self.turn_count.wrapping_add(1),
296 });
297 let emit_neighbors = graph_spawn(Arc::clone(&arc), {
298 let agent_id = agent_id.clone();
299 let eid = episode_id;
300 move |m| {
301 let store = m.sqlite_store();
302 store.query(&agent_id).neighbors(eid, EMIT_TO_EDGE)
303 }
304 })
305 .await;
306 match emit_neighbors {
307 Ok(neighbors) => {
308 for n in neighbors {
309 let target = emit_target_name(&n);
310 self.hooks.on_emit(&target, &emit_payload);
311 }
312 }
313 Err(e) => {
314 let e = e.message_str().unwrap_or("emit").to_string();
315 tracing::warn!(
316 phase = ?TurnPhase::EpisodeWrite,
317 error = %e,
318 "non-fatal turn write failed — continuing"
319 );
320 turn_warnings.push(TurnWarning {
321 phase: TurnPhase::EpisodeWrite,
322 error: format!("emit_routing: {e}"),
323 });
324 }
325 }
326
327 self.turn_count = self.turn_count.wrapping_add(1);
328
329 let should_extract = self.config.extraction_interval > 0
330 && self.turn_count.saturating_sub(self.last_extraction_at_turn)
331 >= self.config.extraction_interval as u64;
332
333 let t_extract = Instant::now();
334 let (extraction_report, _extraction_failed) = if should_extract {
335 let force_fail = std::mem::take(&mut self.test_force_extraction_failure);
336
337 let res = if force_fail {
338 let e = "test_forced".to_string();
339 tracing::warn!(
340 phase = ?TurnPhase::ExtractionPass,
341 error = %e,
342 "non-fatal turn write failed — continuing"
343 );
344 turn_warnings.push(TurnWarning {
345 phase: TurnPhase::ExtractionPass,
346 error: e,
347 });
348 tracing::debug!(
349 target: "ainl_runtime",
350 duration_ms = t_extract.elapsed().as_millis() as u64,
351 signals_ingested = 0u64,
352 skipped = false,
353 "extraction_pass_async"
354 );
355 (None, true)
356 } else {
357 let mem = Arc::clone(&arc);
358 let placeholder = GraphExtractorTask::new(&agent_id);
359 let mut task = std::mem::replace(&mut self.extractor, placeholder);
360 let (task_back, report) = tokio::task::spawn_blocking(move || {
361 let g = mem.lock().expect("graph mutex poisoned");
362 let report = task.run_pass(g.sqlite_store());
363 (task, report)
364 })
365 .await
366 .map_err(|e| AinlRuntimeError::AsyncJoinError(e.to_string()))?;
367 self.extractor = task_back;
368
369 if let Some(ref e) = report.extract_error {
370 tracing::warn!(
371 phase = ?TurnPhase::ExtractionPass,
372 error = %e,
373 "non-fatal turn write failed — continuing"
374 );
375 turn_warnings.push(TurnWarning {
376 phase: TurnPhase::ExtractionPass,
377 error: e.clone(),
378 });
379 }
380 if let Some(ref e) = report.pattern_error {
381 tracing::warn!(
382 phase = ?TurnPhase::PatternPersistence,
383 error = %e,
384 "non-fatal turn write failed — continuing"
385 );
386 turn_warnings.push(TurnWarning {
387 phase: TurnPhase::PatternPersistence,
388 error: e.clone(),
389 });
390 }
391 if let Some(ref e) = report.persona_error {
392 tracing::warn!(
393 phase = ?TurnPhase::PersonaEvolution,
394 error = %e,
395 "non-fatal turn write failed — continuing"
396 );
397 turn_warnings.push(TurnWarning {
398 phase: TurnPhase::PersonaEvolution,
399 error: e.clone(),
400 });
401 }
402 let extraction_failed = report.has_errors();
403 if !extraction_failed {
404 tracing::info!(
405 agent_id = %report.agent_id,
406 signals_extracted = report.signals_extracted,
407 signals_applied = report.signals_applied,
408 semantic_nodes_updated = report.semantic_nodes_updated,
409 "ainl-graph-extractor pass completed (scheduled, async)"
410 );
411 }
412 self.hooks.on_extraction_complete(&report);
413 self.persona_cache = None;
414 tracing::debug!(
415 target: "ainl_runtime",
416 duration_ms = t_extract.elapsed().as_millis() as u64,
417 signals_ingested = report.signals_extracted as u64,
418 skipped = false,
419 "extraction_pass_async"
420 );
421 (Some(report), extraction_failed)
422 };
423 self.last_extraction_at_turn = self.turn_count;
424 res
425 } else {
426 tracing::debug!(
427 target: "ainl_runtime",
428 duration_ms = t_extract.elapsed().as_millis() as u64,
429 signals_ingested = 0u64,
430 skipped = true,
431 "extraction_pass_async"
432 );
433 (None, false)
434 };
435
436 if let Err(e) = graph_spawn(Arc::clone(&arc), {
437 let agent_id = agent_id.clone();
438 move |m| try_export_graph_json_armaraos(m.sqlite_store(), &agent_id)
439 })
440 .await
441 {
442 let e = e.message_str().unwrap_or("export").to_string();
443 tracing::warn!(
444 phase = ?TurnPhase::ExportRefresh,
445 error = %e,
446 "non-fatal turn write failed — continuing"
447 );
448 turn_warnings.push(TurnWarning {
449 phase: TurnPhase::ExportRefresh,
450 error: e,
451 });
452 }
453
454 if !self.config.agent_id.is_empty() {
455 let state = RuntimeStateNode {
456 agent_id: self.config.agent_id.clone(),
457 turn_count: self.turn_count,
458 last_extraction_at_turn: self.last_extraction_at_turn,
459 persona_snapshot_json: self
460 .persona_cache
461 .as_ref()
462 .and_then(|p| serde_json::to_string(p).ok()),
463 updated_at: chrono::Utc::now().timestamp(),
464 };
465 let force_fail = std::mem::take(&mut self.test_force_runtime_state_write_failure);
466 let write_res: Result<(), AinlRuntimeError> = if force_fail {
467 Err(AinlRuntimeError::Message(
468 "injected runtime state write failure".into(),
469 ))
470 } else {
471 graph_spawn(Arc::clone(&arc), move |m| m.write_runtime_state(&state)).await
472 };
473 if let Err(e) = write_res {
474 let e = e.to_string();
475 tracing::warn!(
476 phase = ?TurnPhase::RuntimeStatePersist,
477 error = %e,
478 "failed to persist runtime state — cadence will reset on next restart"
479 );
480 turn_warnings.push(TurnWarning {
481 phase: TurnPhase::RuntimeStatePersist,
482 error: e,
483 });
484 }
485 }
486
487 let result = TurnResult {
488 episode_id,
489 persona_prompt_contribution,
490 memory_context,
491 extraction_report,
492 steps_executed: dispatched_count,
493 patch_dispatch_results,
494 status: TurnStatus::Ok,
495 vitals_gate: input.vitals_gate.clone(),
496 vitals_phase: input.vitals_phase.clone(),
497 vitals_trust: input.vitals_trust,
498 };
499
500 let outcome = if turn_warnings.is_empty() {
501 TurnOutcome::Complete(result)
502 } else {
503 TurnOutcome::PartialSuccess {
504 result,
505 warnings: turn_warnings,
506 }
507 };
508
509 self.hooks.on_turn_complete(&outcome);
510 if let Some(ref hooks_async) = self.hooks_async {
511 hooks_async.on_turn_complete(&outcome).await;
512 }
513 Ok(outcome)
514 }
515
516 async fn dispatch_patches_collect_async(
517 &mut self,
518 patches: &[AinlMemoryNode],
519 frame: &HashMap<String, serde_json::Value>,
520 arc: &Arc<Mutex<GraphMemory>>,
521 turn_warnings: &mut Vec<TurnWarning>,
522 ) -> Result<Vec<PatchDispatchResult>, AinlRuntimeError> {
523 let mut out = Vec::new();
524 for node in patches {
525 let res = self
526 .dispatch_one_patch_async(node, frame, Arc::clone(arc))
527 .await?;
528 if let Some(PatchSkipReason::PersistFailed(ref e)) = res.skip_reason {
529 tracing::warn!(
530 phase = ?TurnPhase::FitnessWriteBack,
531 error = %e,
532 "non-fatal turn write failed — continuing"
533 );
534 turn_warnings.push(TurnWarning {
535 phase: TurnPhase::FitnessWriteBack,
536 error: format!("{}: {}", res.label, e),
537 });
538 }
539 out.push(res);
540 }
541 Ok(out)
542 }
543
544 async fn dispatch_one_patch_async(
545 &mut self,
546 node: &AinlMemoryNode,
547 frame: &HashMap<String, serde_json::Value>,
548 arc: Arc<Mutex<GraphMemory>>,
549 ) -> Result<PatchDispatchResult, AinlRuntimeError> {
550 let label_default = String::new();
551 let (label_src, pv, retired, reads, fitness_opt) = match &node.node_type {
552 AinlNodeType::Procedural { procedural } => (
553 procedural_label(procedural),
554 procedural.patch_version,
555 procedural.retired,
556 procedural.declared_reads.clone(),
557 procedural.fitness,
558 ),
559 _ => {
560 return Ok(PatchDispatchResult {
561 label: label_default,
562 patch_version: 0,
563 fitness_before: 0.0,
564 fitness_after: 0.0,
565 dispatched: false,
566 skip_reason: Some(PatchSkipReason::NotProcedural),
567 adapter_output: None,
568 adapter_name: None,
569 });
570 }
571 };
572
573 if pv == 0 {
574 return Ok(PatchDispatchResult {
575 label: label_src,
576 patch_version: pv,
577 fitness_before: fitness_opt.unwrap_or(0.5),
578 fitness_after: fitness_opt.unwrap_or(0.5),
579 dispatched: false,
580 skip_reason: Some(PatchSkipReason::ZeroVersion),
581 adapter_output: None,
582 adapter_name: None,
583 });
584 }
585 if retired {
586 return Ok(PatchDispatchResult {
587 label: label_src.clone(),
588 patch_version: pv,
589 fitness_before: fitness_opt.unwrap_or(0.5),
590 fitness_after: fitness_opt.unwrap_or(0.5),
591 dispatched: false,
592 skip_reason: Some(PatchSkipReason::Retired),
593 adapter_output: None,
594 adapter_name: None,
595 });
596 }
597 for key in &reads {
598 if !frame.contains_key(key) {
599 return Ok(PatchDispatchResult {
600 label: label_src.clone(),
601 patch_version: pv,
602 fitness_before: fitness_opt.unwrap_or(0.5),
603 fitness_after: fitness_opt.unwrap_or(0.5),
604 dispatched: false,
605 skip_reason: Some(PatchSkipReason::MissingDeclaredRead(key.clone())),
606 adapter_output: None,
607 adapter_name: None,
608 });
609 }
610 }
611
612 let patch_label = label_src.clone();
613 let adapter_key = patch_label.as_str();
614 let ctx = PatchDispatchContext {
615 patch_label: adapter_key,
616 node,
617 frame,
618 };
619 let (adapter_output, adapter_name) = if let Some(adapter) = self
620 .adapter_registry
621 .get(adapter_key)
622 .or_else(|| self.adapter_registry.get(GraphPatchAdapter::NAME))
623 {
624 let aname = adapter.name().to_string();
625 match adapter.execute_patch(&ctx) {
626 Ok(output) => {
627 tracing::debug!(
628 label = %patch_label,
629 adapter = %aname,
630 "adapter executed patch (async)"
631 );
632 (Some(output), Some(aname))
633 }
634 Err(e) => {
635 tracing::warn!(
636 label = %patch_label,
637 adapter = %aname,
638 error = %e,
639 "adapter execution failed — continuing as metadata dispatch"
640 );
641 (None, Some(aname))
642 }
643 }
644 } else {
645 (None, None)
646 };
647
648 let fitness_before = fitness_opt.unwrap_or(0.5);
649 let fitness_after = 0.2_f32 * 1.0 + 0.8 * fitness_before;
650
651 let nid = node.id;
652 let updated = match graph_spawn(Arc::clone(&arc), move |m| {
653 let store = m.sqlite_store();
654 store.read_node(nid)
655 })
656 .await?
657 {
658 Some(mut n) => {
659 if let AinlNodeType::Procedural { ref mut procedural } = n.node_type {
660 procedural.fitness = Some(fitness_after);
661 }
662 n
663 }
664 None => {
665 return Ok(PatchDispatchResult {
666 label: label_src,
667 patch_version: pv,
668 fitness_before,
669 fitness_after: fitness_before,
670 dispatched: false,
671 skip_reason: Some(PatchSkipReason::MissingDeclaredRead("node_row".into())),
672 adapter_output,
673 adapter_name,
674 });
675 }
676 };
677
678 if self.test_force_fitness_write_failure {
679 self.test_force_fitness_write_failure = false;
680 let e = "injected fitness write failure".to_string();
681 return Ok(PatchDispatchResult {
682 label: label_src.clone(),
683 patch_version: pv,
684 fitness_before,
685 fitness_after: fitness_before,
686 dispatched: false,
687 skip_reason: Some(PatchSkipReason::PersistFailed(e)),
688 adapter_output,
689 adapter_name,
690 });
691 }
692
693 let updated_clone = updated.clone();
694 if let Err(e) = graph_spawn(arc, move |m| m.write_node(&updated_clone)).await {
695 return Ok(PatchDispatchResult {
696 label: label_src.clone(),
697 patch_version: pv,
698 fitness_before,
699 fitness_after: fitness_before,
700 dispatched: false,
701 skip_reason: Some(PatchSkipReason::PersistFailed(
702 e.message_str().unwrap_or("write").to_string(),
703 )),
704 adapter_output,
705 adapter_name,
706 });
707 }
708
709 self.hooks
710 .on_patch_dispatched(label_src.as_str(), fitness_after);
711 if let Some(ref hooks_async) = self.hooks_async {
712 let hook_ctx = PatchDispatchContext {
713 patch_label: adapter_key,
714 node,
715 frame,
716 };
717 let _ = hooks_async.on_patch_dispatched(&hook_ctx).await;
718 }
719
720 Ok(PatchDispatchResult {
721 label: label_src,
722 patch_version: pv,
723 fitness_before,
724 fitness_after,
725 dispatched: true,
726 skip_reason: None,
727 adapter_output,
728 adapter_name,
729 })
730 }
731}