1use super::{ExecutionError, StageExecutor};
18use noether_core::stage::StageId;
19use noether_core::types::NType;
20use noether_store::StageStore;
21use serde_json::{json, Value};
22use std::collections::HashMap;
23use std::sync::Mutex;
24
25use crate::index::embedding::EmbeddingProvider;
26use crate::llm::{LlmConfig, LlmProvider, Message};
27
/// Display-ready snapshot of a stage's metadata, captured from the store at
/// construction time so lookups/search never need to touch the store again.
#[derive(Clone)]
struct CachedStage {
    /// Full stage id (the inner `String` of `StageId`).
    id: String,
    /// Human-readable description; also the dispatch key for builtin stages.
    description: String,
    /// `Display` rendering of the signature's input type.
    input_display: String,
    /// `Display` rendering of the signature's output type.
    output_display: String,
    /// Lowercased `Debug` rendering of the lifecycle (e.g. "deprecated").
    lifecycle: String,
    /// `Debug` renderings of the declared effects.
    effects: Vec<String>,
    /// Number of examples attached to the stage.
    examples_count: usize,
}
40
/// Cosine similarity of two equal-length vectors.
///
/// Returns 0.0 (rather than NaN or a panic) for empty inputs, mismatched
/// lengths, or zero-magnitude vectors.
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    if a.is_empty() || a.len() != b.len() {
        return 0.0;
    }
    let mut dot = 0.0f32;
    let mut sq_a = 0.0f32;
    let mut sq_b = 0.0f32;
    for (&x, &y) in a.iter().zip(b.iter()) {
        dot += x * y;
        sq_a += x * x;
        sq_b += y * y;
    }
    let norm_a = sq_a.sqrt();
    let norm_b = sq_b.sqrt();
    if norm_a == 0.0 || norm_b == 0.0 {
        0.0
    } else {
        dot / (norm_a * norm_b)
    }
}
56
/// Executor for Noether's builtin stages (LLM calls, store search,
/// type/composition checks). Builtins are dispatched by description string,
/// not id — see `dispatch`.
pub struct RuntimeExecutor {
    /// Chat LLM provider; `None` means LLM-backed stages fail with a
    /// configuration error.
    llm: Option<Box<dyn LlmProvider>>,
    /// Defaults (model, max_tokens, temperature) used when a request does
    /// not override them.
    llm_config: LlmConfig,
    /// Dedicated embedding provider; when absent, `llm_embed` falls back to
    /// the chat LLM and `store_search` to keyword matching.
    embedding_provider: Option<Box<dyn EmbeddingProvider>>,
    /// Stage id -> description; the dispatch table for builtins.
    descriptions: HashMap<String, String>,
    /// Store-order snapshot of stage metadata for describe/search/verify.
    stage_cache: Vec<CachedStage>,
    /// Stage id -> precomputed description embedding (filled by
    /// `with_embedding`).
    stage_embeddings: HashMap<String, Vec<f32>>,
    /// Completion result cache keyed by request hash; `Mutex` because
    /// `execute` takes `&self` and may run from multiple threads.
    llm_dedup_cache: Mutex<HashMap<String, Value>>,
}
72
73impl RuntimeExecutor {
74 pub fn from_store(store: &dyn StageStore) -> Self {
77 let mut descriptions = HashMap::new();
78 let mut stage_cache = Vec::new();
79
80 for stage in store.list(None) {
81 descriptions.insert(stage.id.0.clone(), stage.description.clone());
82
83 let effects: Vec<String> = stage
84 .signature
85 .effects
86 .iter()
87 .map(|e| format!("{e:?}"))
88 .collect();
89
90 stage_cache.push(CachedStage {
91 id: stage.id.0.clone(),
92 description: stage.description.clone(),
93 input_display: format!("{}", stage.signature.input),
94 output_display: format!("{}", stage.signature.output),
95 lifecycle: format!("{:?}", stage.lifecycle).to_lowercase(),
96 effects,
97 examples_count: stage.examples.len(),
98 });
99 }
100
101 Self {
102 llm: None,
103 llm_config: LlmConfig::default(),
104 embedding_provider: None,
105 descriptions,
106 stage_cache,
107 stage_embeddings: HashMap::new(),
108 llm_dedup_cache: Mutex::new(HashMap::new()),
109 }
110 }
111
112 pub fn with_llm(mut self, llm: Box<dyn LlmProvider>, config: LlmConfig) -> Self {
114 self.llm = Some(llm);
115 self.llm_config = config;
116 self
117 }
118
119 pub fn with_embedding(mut self, provider: Box<dyn EmbeddingProvider>) -> Self {
122 let mut embeddings = HashMap::new();
124 for stage in &self.stage_cache {
125 if let Ok(emb) = provider.embed(&stage.description) {
126 embeddings.insert(stage.id.clone(), emb);
127 }
128 }
129 self.stage_embeddings = embeddings;
130 self.embedding_provider = Some(provider);
131 self
132 }
133
134 pub fn set_llm(&mut self, llm: Box<dyn LlmProvider>, config: LlmConfig) {
136 self.llm = Some(llm);
137 self.llm_config = config;
138 }
139
140 pub fn has_implementation(&self, stage_id: &StageId) -> bool {
142 matches!(
143 self.descriptions.get(&stage_id.0).map(|s| s.as_str()),
144 Some(
145 "Generate text completion using a language model"
146 | "Generate a vector embedding for text"
147 | "Classify text into one of the provided categories"
148 | "Extract structured data from text according to a schema"
149 | "Get detailed information about a stage by its ID"
150 | "Search the stage store by semantic query"
151 | "Check if one type is a structural subtype of another"
152 | "Verify that a composition graph type-checks correctly"
153 | "Register a new stage in the store"
154 | "Retrieve the execution trace of a past composition"
155 )
156 )
157 }
158
    /// Route a builtin stage to its implementation.
    ///
    /// Dispatch is by the stage's *description* string so that stage ids can
    /// be content-addressed or renamed without breaking the runtime. The arms
    /// here must stay in sync with the list in `has_implementation`.
    fn dispatch(&self, stage_id: &StageId, input: &Value) -> Result<Value, ExecutionError> {
        let desc = self
            .descriptions
            .get(&stage_id.0)
            .map(|s| s.as_str())
            .unwrap_or("");

        match desc {
            "Generate text completion using a language model" => self.llm_complete(stage_id, input),
            "Generate a vector embedding for text" => self.llm_embed(stage_id, input),
            "Classify text into one of the provided categories" => {
                self.llm_classify(stage_id, input)
            }
            "Extract structured data from text according to a schema" => {
                self.llm_extract(stage_id, input)
            }
            "Get detailed information about a stage by its ID" => {
                self.stage_describe(stage_id, input)
            }
            "Search the stage store by semantic query" => self.store_search(stage_id, input),
            "Check if one type is a structural subtype of another" => type_check(stage_id, input),
            "Verify that a composition graph type-checks correctly" => {
                self.composition_verify(stage_id, input)
            }
            // store_add / trace_read are deliberately not runnable inside a
            // composition graph; they exist only for out-of-band tooling.
            "Register a new stage in the store" => {
                Err(ExecutionError::StageFailed {
                    stage_id: stage_id.clone(),
                    message: "store_add cannot be called inside a composition graph — use `noether compose` or the synthesis API to register new stages".into(),
                })
            }
            "Retrieve the execution trace of a past composition" => {
                Err(ExecutionError::StageFailed {
                    stage_id: stage_id.clone(),
                    message: "trace_read cannot be called inside a composition graph — use `noether trace <composition_id>` from the CLI".into(),
                })
            }
            // Unknown or empty description: not a builtin this executor owns.
            _ => Err(ExecutionError::StageNotFound(stage_id.clone())),
        }
    }
204
205 fn require_llm(&self, stage_id: &StageId) -> Result<&dyn LlmProvider, ExecutionError> {
208 self.llm.as_deref().ok_or_else(|| ExecutionError::StageFailed {
209 stage_id: stage_id.clone(),
210 message: "LLM provider not configured (set VERTEX_AI_PROJECT, VERTEX_AI_TOKEN, VERTEX_AI_LOCATION)".into(),
211 })
212 }
213
214 fn llm_complete(&self, stage_id: &StageId, input: &Value) -> Result<Value, ExecutionError> {
215 let llm = self.require_llm(stage_id)?;
216
217 let prompt = input["prompt"].as_str().unwrap_or("").to_string();
218 let model = input["model"]
219 .as_str()
220 .unwrap_or(&self.llm_config.model)
221 .to_string();
222 let max_tokens = input["max_tokens"]
223 .as_u64()
224 .map(|v| v as u32)
225 .unwrap_or(self.llm_config.max_tokens);
226 let temperature = input["temperature"]
227 .as_f64()
228 .map(|v| v as f32)
229 .unwrap_or(self.llm_config.temperature);
230 let system_opt = input["system"].as_str();
231
232 let mut messages = vec![];
233 if let Some(sys) = system_opt {
234 messages.push(Message::system(sys));
235 }
236 messages.push(Message::user(&prompt));
237
238 let cfg = LlmConfig {
239 model: model.clone(),
240 max_tokens,
241 temperature,
242 };
243
244 let dedup_key = {
247 use sha2::{Digest, Sha256};
248 let key_data = format!("{}:{}:{}", model, system_opt.unwrap_or(""), prompt);
249 hex::encode(Sha256::digest(key_data.as_bytes()))
250 };
251
252 {
253 let cache = self.llm_dedup_cache.lock().unwrap();
254 if let Some(cached) = cache.get(&dedup_key) {
255 let mut result = cached.clone();
256 result["from_llm_cache"] = json!(true);
257 return Ok(result);
258 }
259 }
260
261 let text = llm
262 .complete(&messages, &cfg)
263 .map_err(|e| ExecutionError::StageFailed {
264 stage_id: stage_id.clone(),
265 message: format!("LLM error: {e}"),
266 })?;
267
268 let tokens_used = text.split_whitespace().count() as u64;
269
270 let result = json!({
271 "text": text,
272 "tokens_used": tokens_used,
273 "model": model,
274 "from_llm_cache": false,
275 });
276
277 self.llm_dedup_cache
278 .lock()
279 .unwrap()
280 .insert(dedup_key, result.clone());
281
282 Ok(result)
283 }
284
    /// Builtin `llm_embed`: produce a vector embedding for `input["text"]`.
    ///
    /// Prefers the dedicated embedding provider when configured; otherwise
    /// falls back to asking the chat LLM to emit a small 8-dimensional JSON
    /// array (a coarse approximation — suitable for demos/tests only).
    fn llm_embed(&self, stage_id: &StageId, input: &Value) -> Result<Value, ExecutionError> {
        let text = input["text"].as_str().unwrap_or("").to_string();
        let model_override = input["model"].as_str().map(|s| s.to_string());

        // Fast path: real embedding provider.
        if let Some(ep) = &self.embedding_provider {
            let emb = ep.embed(&text).map_err(|e| ExecutionError::StageFailed {
                stage_id: stage_id.clone(),
                message: format!("embedding provider error: {e}"),
            })?;
            let dims = emb.len() as u64;
            // The provider API does not report a model name; use the caller's
            // override or a generic placeholder.
            let model = model_override.unwrap_or_else(|| "embedding-model".into());
            return Ok(json!({
                "embedding": emb,
                "dimensions": dims,
                "model": model,
            }));
        }

        // Fallback: prompt the chat LLM to fabricate a tiny embedding.
        let llm = self.require_llm(stage_id)?;
        let model = model_override.unwrap_or_else(|| "text-embedding-004".to_string());

        let prompt = format!(
            "Generate a compact 8-dimensional embedding vector for this text as a JSON array of floats: \"{text}\". Respond ONLY with a JSON array like [0.1, -0.2, ...]."
        );
        let messages = vec![
            Message::system("You are an embedding model. Respond only with a JSON float array."),
            Message::user(&prompt),
        ];
        let cfg = LlmConfig {
            model: model.clone(),
            max_tokens: 128,
            temperature: 0.0,
        };

        let response = llm
            .complete(&messages, &cfg)
            .map_err(|e| ExecutionError::StageFailed {
                stage_id: stage_id.clone(),
                message: format!("LLM error: {e}"),
            })?;

        // The model may wrap the array in prose; pull out the outermost one.
        let embedding: Value =
            extract_json_array(&response).ok_or_else(|| ExecutionError::StageFailed {
                stage_id: stage_id.clone(),
                message: format!("could not parse embedding from LLM response: {response:?}"),
            })?;

        let dims = embedding.as_array().map(|a| a.len()).unwrap_or(0) as u64;

        Ok(json!({
            "embedding": embedding,
            "dimensions": dims,
            "model": model,
        }))
    }
342
343 fn llm_classify(&self, stage_id: &StageId, input: &Value) -> Result<Value, ExecutionError> {
344 let llm = self.require_llm(stage_id)?;
345
346 let text = input["text"].as_str().unwrap_or("").to_string();
347 let model = input["model"]
348 .as_str()
349 .unwrap_or(&self.llm_config.model)
350 .to_string();
351 let categories: Vec<String> = input["categories"]
352 .as_array()
353 .map(|a| {
354 a.iter()
355 .filter_map(|v| v.as_str())
356 .map(|s| s.to_string())
357 .collect()
358 })
359 .unwrap_or_default();
360
361 if categories.is_empty() {
362 return Err(ExecutionError::StageFailed {
363 stage_id: stage_id.clone(),
364 message: "categories list is empty".into(),
365 });
366 }
367
368 let cats_str = categories.join(", ");
369 let prompt = format!(
370 "Classify the following text into EXACTLY ONE of these categories: {cats_str}\n\nText: \"{text}\"\n\nRespond with ONLY valid JSON: {{\"category\": \"<one of the categories>\", \"confidence\": <0.0-1.0>}}"
371 );
372
373 let messages = vec![
374 Message::system(
375 "You are a text classifier. Always respond with valid JSON only. No explanation.",
376 ),
377 Message::user(&prompt),
378 ];
379 let cfg = LlmConfig {
380 model: model.clone(),
381 max_tokens: 64,
382 temperature: 0.0,
383 };
384
385 let response = llm
386 .complete(&messages, &cfg)
387 .map_err(|e| ExecutionError::StageFailed {
388 stage_id: stage_id.clone(),
389 message: format!("LLM error: {e}"),
390 })?;
391
392 let parsed: Value =
393 extract_json_object(&response).ok_or_else(|| ExecutionError::StageFailed {
394 stage_id: stage_id.clone(),
395 message: format!("could not parse classification JSON from: {response:?}"),
396 })?;
397
398 let category = parsed["category"].as_str().unwrap_or("").trim().to_string();
399 if !categories.contains(&category) {
400 return Err(ExecutionError::StageFailed {
401 stage_id: stage_id.clone(),
402 message: format!(
403 "LLM returned unknown category {category:?}; expected one of: {cats_str}"
404 ),
405 });
406 }
407
408 let confidence = parsed["confidence"].as_f64().unwrap_or(1.0);
409
410 Ok(json!({
411 "category": category,
412 "confidence": confidence,
413 "model": model,
414 }))
415 }
416
417 fn llm_extract(&self, stage_id: &StageId, input: &Value) -> Result<Value, ExecutionError> {
418 let llm = self.require_llm(stage_id)?;
419
420 let text = input["text"].as_str().unwrap_or("").to_string();
421 let model = input["model"]
422 .as_str()
423 .unwrap_or(&self.llm_config.model)
424 .to_string();
425 let schema = input.get("schema").cloned().unwrap_or(json!({}));
426 let schema_str = serde_json::to_string_pretty(&schema).unwrap_or_else(|_| "{}".to_string());
427
428 let prompt = format!(
429 "Extract structured data from the following text.\nSchema: {schema_str}\nText: \"{text}\"\n\nRespond with ONLY a valid JSON object matching the schema. No explanation."
430 );
431
432 let messages = vec![
433 Message::system(
434 "You are a structured data extractor. Always respond with valid JSON only.",
435 ),
436 Message::user(&prompt),
437 ];
438 let cfg = LlmConfig {
439 model: model.clone(),
440 max_tokens: 512,
441 temperature: 0.0,
442 };
443
444 let response = llm
445 .complete(&messages, &cfg)
446 .map_err(|e| ExecutionError::StageFailed {
447 stage_id: stage_id.clone(),
448 message: format!("LLM error: {e}"),
449 })?;
450
451 let extracted =
452 extract_json_object(&response).ok_or_else(|| ExecutionError::StageFailed {
453 stage_id: stage_id.clone(),
454 message: format!("could not parse extraction JSON from: {response:?}"),
455 })?;
456
457 Ok(json!({
458 "extracted": extracted,
459 "model": model,
460 }))
461 }
462
463 fn stage_describe(&self, stage_id: &StageId, input: &Value) -> Result<Value, ExecutionError> {
466 let id = input["id"].as_str().unwrap_or("").to_string();
467
468 let cached = self
469 .stage_cache
470 .iter()
471 .find(|s| s.id == id || s.id.starts_with(&id))
472 .ok_or_else(|| ExecutionError::StageFailed {
473 stage_id: stage_id.clone(),
474 message: format!("stage {id:?} not found"),
475 })?;
476
477 Ok(json!({
478 "id": cached.id,
479 "description": cached.description,
480 "input": cached.input_display,
481 "output": cached.output_display,
482 "effects": cached.effects,
483 "lifecycle": cached.lifecycle,
484 "examples_count": cached.examples_count,
485 }))
486 }
487
488 fn store_search(&self, _stage_id: &StageId, input: &Value) -> Result<Value, ExecutionError> {
494 let query = input["query"].as_str().unwrap_or("");
495 let limit = input["limit"].as_u64().unwrap_or(10) as usize;
496
497 if let Some(ep) = &self.embedding_provider {
498 if let Ok(query_emb) = ep.embed(query) {
500 let mut scored: Vec<(f32, &CachedStage)> = self
501 .stage_cache
502 .iter()
503 .filter_map(|s| {
504 self.stage_embeddings
505 .get(&s.id)
506 .map(|emb| (cosine_similarity(&query_emb, emb), s))
507 })
508 .collect();
509
510 scored.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal));
511
512 let results: Vec<Value> = scored
513 .into_iter()
514 .take(limit)
515 .map(|(score, s)| {
516 json!({
517 "id": s.id,
518 "description": s.description,
519 "input": s.input_display,
520 "output": s.output_display,
521 "score": score,
522 })
523 })
524 .collect();
525
526 return Ok(Value::Array(results));
527 }
528 }
529
530 let query_lc = query.to_lowercase();
532 let results: Vec<Value> = self
533 .stage_cache
534 .iter()
535 .filter(|s| {
536 s.description.to_lowercase().contains(&query_lc)
537 || s.input_display.to_lowercase().contains(&query_lc)
538 || s.output_display.to_lowercase().contains(&query_lc)
539 })
540 .take(limit)
541 .map(|s| {
542 json!({
543 "id": s.id,
544 "description": s.description,
545 "input": s.input_display,
546 "output": s.output_display,
547 "score": 1.0,
548 })
549 })
550 .collect();
551
552 Ok(Value::Array(results))
553 }
554
    /// Builtin `composition_verify`: statically check a composition graph
    /// described by `input["stages"]` (ids) and `input["operators"]`.
    ///
    /// Returns `{valid, errors, warnings}`; `valid` is simply
    /// `errors.is_empty()`. Checks: operator names, stage existence,
    /// lifecycle (deprecated → warning, tombstone → error), and — for
    /// sequential compositions — pairwise output/input type compatibility.
    fn composition_verify(
        &self,
        stage_id: &StageId,
        input: &Value,
    ) -> Result<Value, ExecutionError> {
        let stage_ids: Vec<&str> = input["stages"]
            .as_array()
            .map(|a| a.iter().filter_map(|v| v.as_str()).collect())
            .unwrap_or_default();

        let operators: Vec<&str> = input["operators"]
            .as_array()
            .map(|a| a.iter().filter_map(|v| v.as_str()).collect())
            .unwrap_or_default();

        let mut errors: Vec<String> = vec![];
        let mut warnings: Vec<String> = vec![];

        // An empty composition is vacuously valid, with a warning.
        if stage_ids.is_empty() {
            warnings.push("empty composition".into());
            return Ok(json!({ "valid": true, "errors": errors, "warnings": warnings }));
        }

        // Operator names are compared case-insensitively.
        let valid_ops = [
            "sequential",
            "parallel",
            "branch",
            "fanout",
            "merge",
            "retry",
        ];
        for op in &operators {
            let op_lc = op.to_lowercase();
            if !valid_ops.contains(&op_lc.as_str()) {
                errors.push(format!("unknown operator: {op}"));
            }
        }

        let id_to_cache: HashMap<&str, &CachedStage> = self
            .stage_cache
            .iter()
            .map(|s| (s.id.as_str(), s))
            .collect();

        // Resolve each id; missing ids become errors, lifecycle issues become
        // warnings/errors. NOTE(review): `resolved_stages` skips missing ids,
        // so in the type-check loop below `stage_ids[i]` can name the wrong
        // stage when some ids were unresolved — harmless to `valid` (an error
        // was already recorded) but the mismatch message may cite wrong ids.
        let mut resolved_stages: Vec<&CachedStage> = vec![];
        for sid in &stage_ids {
            match id_to_cache.get(sid) {
                Some(s) => {
                    if s.lifecycle == "deprecated" {
                        warnings.push(format!("stage {} ({}) is deprecated", sid, s.description));
                    }
                    if s.lifecycle == "tombstone" {
                        errors.push(format!(
                            "stage {} is a tombstone and cannot be executed",
                            sid
                        ));
                    }
                    resolved_stages.push(s);
                }
                None => {
                    errors.push(format!("stage {sid} not found in store"));
                }
            }
        }

        // Sequential chains: check each adjacent output→input pair. Types are
        // only known as Display strings here, so recover an NType by trying
        // serde (quoting the string makes primitive names parse as JSON
        // strings) and then the primitive-name fallback; pairs that cannot be
        // recovered are silently skipped rather than reported.
        if operators.iter().any(|op| op.to_lowercase() == "sequential") && resolved_stages.len() > 1
        {
            for i in 0..resolved_stages.len() - 1 {
                let out_str = &resolved_stages[i].output_display;
                let in_str = &resolved_stages[i + 1].input_display;

                let out_type: Option<NType> = serde_json::from_str(&format!("\"{}\"", out_str))
                    .ok()
                    .or_else(|| parse_ntype_display(out_str));
                let in_type: Option<NType> = serde_json::from_str(&format!("\"{}\"", in_str))
                    .ok()
                    .or_else(|| parse_ntype_display(in_str));

                if let (Some(out), Some(inp)) = (out_type, in_type) {
                    use noether_core::types::{is_subtype_of, TypeCompatibility};
                    if let TypeCompatibility::Incompatible(reason) = is_subtype_of(&out, &inp) {
                        errors.push(format!(
                            "type mismatch between stages {} and {}: {} is not compatible with {} ({})",
                            stage_ids[i], stage_ids[i + 1], out_str, in_str, reason
                        ));
                    }
                }
            }
        }

        // stage_id is unused here (verification never fails the stage itself).
        let _ = stage_id;

        let valid = errors.is_empty();
        Ok(json!({
            "valid": valid,
            "errors": errors,
            "warnings": warnings,
        }))
    }
664}
665
// Engine entry point: every builtin stage routes through description-based
// `dispatch`.
impl StageExecutor for RuntimeExecutor {
    fn execute(&self, stage_id: &StageId, input: &Value) -> Result<Value, ExecutionError> {
        self.dispatch(stage_id, input)
    }
}
671
672fn type_check(stage_id: &StageId, input: &Value) -> Result<Value, ExecutionError> {
676 use noether_core::types::{is_subtype_of, TypeCompatibility};
677
678 let sub = parse_ntype_input(&input["sub"]).ok_or_else(|| ExecutionError::StageFailed {
679 stage_id: stage_id.clone(),
680 message: format!("could not parse sub type from: {}", input["sub"]),
681 })?;
682
683 let sup = parse_ntype_input(&input["sup"]).ok_or_else(|| ExecutionError::StageFailed {
684 stage_id: stage_id.clone(),
685 message: format!("could not parse sup type from: {}", input["sup"]),
686 })?;
687
688 match is_subtype_of(&sub, &sup) {
689 TypeCompatibility::Compatible => Ok(json!({"compatible": true, "reason": null})),
690 TypeCompatibility::Incompatible(reason) => {
691 Ok(json!({"compatible": false, "reason": format!("{reason}")}))
692 }
693 }
694}
695
696fn parse_ntype_input(v: &Value) -> Option<NType> {
702 if let Some(s) = v.as_str() {
703 match s {
704 "Text" => return Some(NType::Text),
705 "Number" => return Some(NType::Number),
706 "Bool" => return Some(NType::Bool),
707 "Any" => return Some(NType::Any),
708 "Null" => return Some(NType::Null),
709 "Bytes" => return Some(NType::Bytes),
710 _ => {}
711 }
712 }
713 serde_json::from_value(v.clone()).ok()
714}
715
716fn parse_ntype_display(s: &str) -> Option<NType> {
719 match s.trim() {
720 "Text" => Some(NType::Text),
721 "Number" => Some(NType::Number),
722 "Bool" => Some(NType::Bool),
723 "Any" => Some(NType::Any),
724 "Null" => Some(NType::Null),
725 "Bytes" => Some(NType::Bytes),
726 "VNode" => Some(NType::VNode),
727 _ => None,
728 }
729}
730
731fn extract_json_array(s: &str) -> Option<Value> {
733 let start = s.find('[')?;
734 let end = s.rfind(']').map(|i| i + 1)?;
735 serde_json::from_str(&s[start..end]).ok()
736}
737
738fn extract_json_object(s: &str) -> Option<Value> {
740 let start = s.find('{')?;
741 let end = s.rfind('}').map(|i| i + 1)?;
742 serde_json::from_str(&s[start..end]).ok()
743}
744
// Unit tests run the executor against the stdlib stage set loaded into an
// in-memory store; no network or real LLM provider is involved. Stages are
// located by description substring so the tests don't hard-code stage ids.
#[cfg(test)]
mod tests {
    use super::*;
    use noether_core::stdlib::load_stdlib;
    use noether_store::MemoryStore;

    // Executor over the full stdlib, with no LLM/embedding provider attached.
    fn stdlib_runtime() -> RuntimeExecutor {
        let mut store = MemoryStore::new();
        for s in load_stdlib() {
            let _ = store.put(s);
        }
        RuntimeExecutor::from_store(&store)
    }

    // Identical primitive types are compatible with a null reason.
    #[test]
    fn type_check_compatible() {
        let rt = stdlib_runtime();
        let id = rt
            .descriptions
            .iter()
            .find(|(_, v)| v.contains("structural subtype"))
            .map(|(k, _)| StageId(k.clone()))
            .unwrap();
        let result = rt
            .execute(&id, &json!({"sub": "Text", "sup": "Text"}))
            .unwrap();
        assert_eq!(result["compatible"], json!(true));
        assert_eq!(result["reason"], json!(null));
    }

    // Distinct primitives are incompatible and carry a textual reason.
    #[test]
    fn type_check_incompatible() {
        let rt = stdlib_runtime();
        let id = rt
            .descriptions
            .iter()
            .find(|(_, v)| v.contains("structural subtype"))
            .map(|(k, _)| StageId(k.clone()))
            .unwrap();
        let result = rt
            .execute(&id, &json!({"sub": "Text", "sup": "Number"}))
            .unwrap();
        assert_eq!(result["compatible"], json!(false));
        assert!(result["reason"].is_string());
    }

    // stage_describe echoes id/description and exposes effects + examples.
    #[test]
    fn stage_describe_includes_effects() {
        let rt = stdlib_runtime();
        let describe_id = rt
            .descriptions
            .iter()
            .find(|(_, v)| v.contains("Get detailed information"))
            .map(|(k, _)| StageId(k.clone()))
            .unwrap();
        let to_text_id = rt
            .descriptions
            .iter()
            .find(|(_, v)| v.contains("Convert any value to its text"))
            .map(|(k, _)| k.clone())
            .unwrap();

        let result = rt
            .execute(&describe_id, &json!({"id": to_text_id}))
            .unwrap();
        assert_eq!(result["id"], json!(to_text_id));
        assert!(result["description"].as_str().unwrap().contains("text"));
        assert!(result["effects"].is_array(), "effects should be an array");
        assert!(result["examples_count"].as_u64().unwrap() > 0);
    }

    // Keyword fallback path (no embedding provider configured).
    #[test]
    fn store_search_finds_stages() {
        let rt = stdlib_runtime();
        let search_id = rt
            .descriptions
            .iter()
            .find(|(_, v)| v.contains("Search the stage store"))
            .map(|(k, _)| StageId(k.clone()))
            .unwrap();
        let result = rt
            .execute(&search_id, &json!({"query": "sort", "limit": 5}))
            .unwrap();
        let hits = result.as_array().unwrap();
        assert!(!hits.is_empty());
        assert!(hits
            .iter()
            .any(|h| h["description"].as_str().unwrap_or("").contains("Sort")));
    }

    // Semantic path: mock embeddings yield scores within the cosine range.
    #[test]
    fn store_search_with_embedding_provider() {
        use crate::index::embedding::MockEmbeddingProvider;
        let mut store = MemoryStore::new();
        for s in load_stdlib() {
            let _ = store.put(s);
        }
        let rt = RuntimeExecutor::from_store(&store)
            .with_embedding(Box::new(MockEmbeddingProvider::new(32)));

        let search_id = rt
            .descriptions
            .iter()
            .find(|(_, v)| v.contains("Search the stage store"))
            .map(|(k, _)| StageId(k.clone()))
            .unwrap();
        let result = rt
            .execute(&search_id, &json!({"query": "sort list", "limit": 10}))
            .unwrap();
        let hits = result.as_array().unwrap();
        assert!(!hits.is_empty());
        for h in hits {
            let score = h["score"].as_f64().unwrap();
            assert!((0.0..=1.0).contains(&score), "score {score} out of range");
        }
    }

    // Well-formed sequential composition produces errors/warnings arrays.
    #[test]
    fn composition_verify_valid_stages() {
        let rt = stdlib_runtime();
        let verify_id = rt
            .descriptions
            .iter()
            .find(|(_, v)| v.contains("Verify that a composition graph"))
            .map(|(k, _)| StageId(k.clone()))
            .unwrap();

        let ids: Vec<String> = rt
            .stage_cache
            .iter()
            .take(2)
            .map(|s| s.id.clone())
            .collect();

        let result = rt
            .execute(
                &verify_id,
                &json!({
                    "stages": ids,
                    "operators": ["sequential"]
                }),
            )
            .unwrap();
        assert!(result["errors"].is_array());
        assert!(result["warnings"].is_array());
    }

    // Unknown stage ids must invalidate the composition with a clear error.
    #[test]
    fn composition_verify_unknown_stage_is_error() {
        let rt = stdlib_runtime();
        let verify_id = rt
            .descriptions
            .iter()
            .find(|(_, v)| v.contains("Verify that a composition graph"))
            .map(|(k, _)| StageId(k.clone()))
            .unwrap();

        let result = rt
            .execute(
                &verify_id,
                &json!({
                    "stages": ["nonexistent-stage-id"],
                    "operators": []
                }),
            )
            .unwrap();
        assert_eq!(result["valid"], json!(false));
        assert!(result["errors"]
            .as_array()
            .unwrap()
            .iter()
            .any(|e| { e.as_str().unwrap_or("").contains("not found") }));
    }

    // Without an LLM provider, llm_complete fails with a config error rather
    // than panicking.
    #[test]
    fn llm_complete_fails_gracefully_without_llm() {
        let rt = stdlib_runtime();
        let llm_id = rt
            .descriptions
            .iter()
            .find(|(_, v)| v.contains("Generate text completion"))
            .map(|(k, _)| StageId(k.clone()))
            .unwrap();
        let result = rt.execute(
            &llm_id,
            &json!({"prompt": "Hello", "model": null, "max_tokens": null, "temperature": null, "system": null}),
        );
        assert!(result.is_err());
        let msg = result.unwrap_err().to_string();
        assert!(
            msg.contains("LLM provider not configured"),
            "expected config error, got: {msg}"
        );
    }

    // With an embedding provider, llm_embed uses it directly (dimensions
    // match the provider, no LLM fallback needed).
    #[test]
    fn llm_embed_uses_embedding_provider_when_available() {
        use crate::index::embedding::MockEmbeddingProvider;
        let mut store = MemoryStore::new();
        for s in load_stdlib() {
            let _ = store.put(s);
        }
        let rt = RuntimeExecutor::from_store(&store)
            .with_embedding(Box::new(MockEmbeddingProvider::new(16)));

        let embed_id = rt
            .descriptions
            .iter()
            .find(|(_, v)| v.contains("Generate a vector embedding"))
            .map(|(k, _)| StageId(k.clone()))
            .unwrap();

        let result = rt
            .execute(&embed_id, &json!({"text": "hello world", "model": null}))
            .unwrap();
        assert_eq!(result["dimensions"], json!(16u64));
        assert_eq!(result["embedding"].as_array().unwrap().len(), 16);
    }

    // Hammer the dedup cache from 16 threads with one identical request:
    // every thread must succeed and agree on the classification.
    #[test]
    fn llm_dedup_cache_concurrent_access() {
        use crate::llm::MockLlmProvider;
        use std::sync::Arc;

        let mock_response = r#"{"category":"positive","confidence":0.99,"model":"mock"}"#;

        let mut store = MemoryStore::new();
        for s in load_stdlib() {
            let _ = store.put(s);
        }

        let rt = RuntimeExecutor::from_store(&store).with_llm(
            Box::new(MockLlmProvider::new(mock_response)),
            LlmConfig::default(),
        );
        let rt = Arc::new(rt);

        let classify_id = rt
            .descriptions
            .iter()
            .find(|(_, v)| v.contains("Classify text into one of"))
            .map(|(k, _)| StageId(k.clone()))
            .expect("classify_text stage not found");

        let input = serde_json::json!({
            "text": "I love this product",
            "categories": ["positive", "negative", "neutral"],
            "model": null
        });

        let results: Vec<_> = std::thread::scope(|s| {
            let handles: Vec<_> = (0..16)
                .map(|_| {
                    let rt = Arc::clone(&rt);
                    let id = classify_id.clone();
                    let inp = input.clone();
                    s.spawn(move || rt.execute(&id, &inp))
                })
                .collect();
            handles.into_iter().map(|h| h.join().unwrap()).collect()
        });

        assert_eq!(results.len(), 16);
        let first = results[0].as_ref().expect("first result must be Ok");
        for (i, r) in results.iter().enumerate() {
            let val = r
                .as_ref()
                .unwrap_or_else(|e| panic!("thread {i} failed: {e}"));
            assert_eq!(
                val["category"], first["category"],
                "thread {i} returned different category"
            );
        }
        assert_eq!(first["category"].as_str().unwrap(), "positive");
    }
}