1use std::any::Any;
2use std::collections::BTreeMap;
3use std::sync::Arc;
4
5use lash_core::{
6 PromptContribution, ProtocolSessionExtension, ProtocolTurnExtension,
7 ProtocolTurnExtensionHandle, TurnInput,
8};
9
/// Plugin-input id under which `RlmTurnInputExt` stores the merged
/// `RlmProjectionExtension` in a turn's `TurnContext`.
pub(crate) const RLM_TURN_INPUT_PLUGIN_ID: &str = "rlm";
11use lashlang::{
12 ProjectedBindingError, ProjectedBindings, ProjectedHostValue, ProjectedValue,
13 Value as FlowValue,
14};
15
16#[derive(Clone, Debug, PartialEq, serde::Serialize, serde::Deserialize)]
17pub struct ProjectionRef {
18 pub kind: String,
19 pub key: serde_json::Value,
20}
21
22impl ProjectionRef {
23 pub fn new(kind: impl Into<String>, key: serde_json::Value) -> Self {
24 Self {
25 kind: kind.into(),
26 key,
27 }
28 }
29}
30
31#[derive(Clone, Debug, PartialEq, Eq)]
32pub struct ProjectionResolveError {
33 message: String,
34}
35
36impl ProjectionResolveError {
37 pub fn unavailable(reference: &ProjectionRef) -> Self {
38 Self {
39 message: format!(
40 "projection ref unavailable: kind `{}`, key {}",
41 reference.kind, reference.key
42 ),
43 }
44 }
45
46 pub fn invalid(message: impl Into<String>) -> Self {
47 Self {
48 message: message.into(),
49 }
50 }
51}
52
53impl std::fmt::Display for ProjectionResolveError {
54 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
55 f.write_str(&self.message)
56 }
57}
58
59impl std::error::Error for ProjectionResolveError {}
60
/// Source of live projected values for [`ProjectionRef`]s.
///
/// Implementations look up `reference` and return the host value it names, or a
/// [`ProjectionResolveError`] when it cannot be resolved.
#[async_trait::async_trait]
pub trait ProjectionResolver: Send + Sync {
    /// Resolves `reference` to the projected host value it names.
    async fn resolve_projection(
        &self,
        reference: &ProjectionRef,
    ) -> Result<Arc<dyn ProjectedHostValue>, ProjectionResolveError>;
}
68
69#[derive(Clone, Default)]
70pub struct ProjectionRegistry {
71 memory: Arc<std::sync::RwLock<BTreeMap<String, Arc<dyn ProjectedHostValue>>>>,
72}
73
74impl ProjectionRegistry {
75 pub fn new() -> Self {
76 Self::default()
77 }
78
79 pub fn register_memory(&self, value: Arc<dyn ProjectedHostValue>) -> ProjectionRef {
80 let key = uuid::Uuid::new_v4().to_string();
81 self.memory
82 .write()
83 .expect("projection registry lock")
84 .insert(key.clone(), value);
85 ProjectionRef::new("memory", serde_json::Value::String(key))
86 }
87}
88
89#[async_trait::async_trait]
90impl ProjectionResolver for ProjectionRegistry {
91 async fn resolve_projection(
92 &self,
93 reference: &ProjectionRef,
94 ) -> Result<Arc<dyn ProjectedHostValue>, ProjectionResolveError> {
95 if reference.kind != "memory" {
96 return Err(ProjectionResolveError::unavailable(reference));
97 }
98 let Some(key) = reference.key.as_str() else {
99 return Err(ProjectionResolveError::invalid(
100 "memory projection ref key must be a string",
101 ));
102 };
103 self.memory
104 .read()
105 .expect("projection registry lock")
106 .get(key)
107 .cloned()
108 .ok_or_else(|| ProjectionResolveError::unavailable(reference))
109 }
110}
111
/// A single projected binding before resolution.
#[derive(Clone)]
enum RlmProjectedBinding {
    /// Value known up front; projected as a scalar.
    Value(FlowValue),
    /// Value fetched through a [`ProjectionResolver`] when the bindings are projected.
    Lazy(ProjectionRef),
}
117
/// Builder for the read-only variables exposed to RLM `lashlang` code.
///
/// Names are unique: the `bind_*` and `merge` methods reject an already-bound name.
/// Entries live in a `BTreeMap`, so iteration is in name order.
#[derive(Clone, Default)]
pub struct RlmProjectedBindings {
    bindings: BTreeMap<String, RlmProjectedBinding>,
}
122
/// Shared callback that may turn a tool result into a projected value.
///
/// Called with a string and a JSON value. NOTE(review): these are presumably the tool
/// name and its result; confirm at the call site. Returning `Some(value)` supplies a
/// value and `None` declines.
pub type RlmToolResultProjector =
    Arc<dyn Fn(&str, &serde_json::Value) -> Option<FlowValue> + Send + Sync + 'static>;
125
126#[derive(Clone, Debug, PartialEq, Eq)]
127pub enum RlmProjectedSeedError {
128 Binding(ProjectedBindingError),
129 InvalidProjectionRef { name: String, source: String },
130}
131
132impl RlmProjectedSeedError {
133 pub fn invalid_projection_ref(name: impl Into<String>, source: impl std::fmt::Display) -> Self {
134 Self::InvalidProjectionRef {
135 name: name.into(),
136 source: source.to_string(),
137 }
138 }
139}
140
141impl From<ProjectedBindingError> for RlmProjectedSeedError {
142 fn from(value: ProjectedBindingError) -> Self {
143 Self::Binding(value)
144 }
145}
146
147impl std::fmt::Display for RlmProjectedSeedError {
148 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
149 match self {
150 Self::Binding(err) => err.fmt(f),
151 Self::InvalidProjectionRef { name, source } => {
152 write!(
153 f,
154 "invalid projection ref for projected seed `{name}`: {source}"
155 )
156 }
157 }
158 }
159}
160
161impl std::error::Error for RlmProjectedSeedError {}
162
163impl RlmProjectedBindings {
164 pub fn new() -> Self {
165 Self::default()
166 }
167
168 pub fn bind_value(
169 mut self,
170 name: impl Into<String>,
171 value: impl Into<FlowValue>,
172 ) -> Result<Self, ProjectedBindingError> {
173 let name = name.into();
174 if self.bindings.contains_key(&name) {
175 return Err(ProjectedBindingError::duplicate(name));
176 }
177 self.bindings
178 .insert(name, RlmProjectedBinding::Value(value.into()));
179 Ok(self)
180 }
181
182 pub fn bind_json(
183 self,
184 name: impl Into<String>,
185 value: serde_json::Value,
186 ) -> Result<Self, ProjectedBindingError> {
187 self.bind_value(name, lashlang::from_json(value))
188 }
189
190 pub fn bind_lazy(
191 mut self,
192 name: impl Into<String>,
193 reference: ProjectionRef,
194 ) -> Result<Self, ProjectedBindingError> {
195 let name = name.into();
196 if self.bindings.contains_key(&name) {
197 return Err(ProjectedBindingError::duplicate(name));
198 }
199 self.bindings
200 .insert(name, RlmProjectedBinding::Lazy(reference));
201 Ok(self)
202 }
203
204 pub fn names(&self) -> impl Iterator<Item = String> + '_ {
205 self.bindings.keys().cloned()
206 }
207
208 pub(crate) async fn into_projected_bindings(
209 self,
210 resolver: Arc<dyn ProjectionResolver>,
211 ) -> Result<ProjectedBindings, ProjectionResolveError> {
212 let mut out = ProjectedBindings::new();
213 for (name, binding) in self.bindings {
214 let value = match binding {
215 RlmProjectedBinding::Value(value) => ProjectedValue::scalar(name.clone(), value),
216 RlmProjectedBinding::Lazy(reference) => {
217 let resolved = resolver.resolve_projection(&reference).await?;
218 let ref_json = serde_json::to_value(&reference).map_err(|err| {
219 ProjectionResolveError::invalid(format!(
220 "projection ref did not serialize: {err}"
221 ))
222 })?;
223 ProjectedValue::custom_with_projection_ref(name.clone(), resolved, ref_json)
224 }
225 };
226 out.try_insert(name, value)
227 .expect("RLM projected bindings already reject duplicates");
228 }
229 Ok(out)
230 }
231
232 pub fn merge(mut self, other: Self) -> Result<Self, ProjectedBindingError> {
233 for (name, value) in other.bindings {
234 if self.bindings.contains_key(&name) {
235 return Err(ProjectedBindingError::duplicate(name));
236 }
237 self.bindings.insert(name, value);
238 }
239 Ok(self)
240 }
241
242 pub fn from_snapshot(
247 snapshot: &lash_rlm_types::RlmProjectedSeedSnapshot,
248 ) -> Result<Self, RlmProjectedSeedError> {
249 let mut out = Self::new();
250 for (name, value) in &snapshot.entries {
251 out = if let Some(reference) =
252 super::transport::projection_ref_from_seed_value(name, value)?
253 {
254 out.bind_lazy(name.clone(), reference)?
255 } else {
256 out.bind_json(name.clone(), value.clone())?
257 };
258 }
259 Ok(out)
260 }
261}
262
/// Extension carrying RLM projected bindings and tool-result projectors; used both
/// as a turn/session protocol extension and as the turn's RLM plugin input.
#[derive(Clone, Default)]
pub(crate) struct RlmProjectionExtension {
    /// Read-only variables to expose; these also drive the prompt contribution.
    pub(crate) bindings: RlmProjectedBindings,
    /// Projectors registered via `rlm_project_tool_results`, appended on each merge.
    pub(crate) tool_result_projectors: Vec<RlmToolResultProjector>,
}
268
269impl RlmProjectionExtension {
270 pub(crate) fn new(bindings: RlmProjectedBindings) -> Self {
271 Self {
272 bindings,
273 tool_result_projectors: Vec::new(),
274 }
275 }
276
277 pub(crate) fn with_projector(projector: RlmToolResultProjector) -> Self {
278 Self {
279 bindings: RlmProjectedBindings::new(),
280 tool_result_projectors: vec![projector],
281 }
282 }
283
284 fn merge(mut self, other: Self) -> Result<Self, ProjectedBindingError> {
285 self.bindings = self.bindings.merge(other.bindings)?;
286 self.tool_result_projectors
287 .extend(other.tool_result_projectors);
288 Ok(self)
289 }
290
291 pub(crate) fn prompt_contributions_for(
292 bindings: &RlmProjectedBindings,
293 ) -> Vec<PromptContribution> {
294 let mut names = bindings.names().collect::<Vec<_>>();
295 if names.is_empty() {
296 return Vec::new();
297 }
298 names.sort();
299 let mut lines = vec![
300 "These read-only values are already in scope. Access them directly in fenced `lashlang` code; do not recreate them manually.".to_string(),
301 String::new(),
302 "Read-only variables:".to_string(),
303 ];
304 for name in names {
305 lines.push(format!("- `{name}`: read-only value"));
306 }
307 vec![PromptContribution::environment(
308 "Read-Only Variables",
309 lines.join("\n"),
310 )]
311 }
312}
313
314impl ProtocolTurnExtension for RlmProjectionExtension {
315 fn as_any(&self) -> &dyn Any {
316 self
317 }
318
319 fn prompt_contributions(&self) -> Vec<PromptContribution> {
320 Self::prompt_contributions_for(&self.bindings)
321 }
322}
323
324impl ProtocolSessionExtension for RlmProjectionExtension {
325 fn as_any(&self) -> &dyn Any {
326 self
327 }
328}
329
330pub fn rlm_session_projection_extension(
331 bindings: RlmProjectedBindings,
332) -> lash_core::ProtocolSessionExtensionHandle {
333 lash_core::ProtocolSessionExtensionHandle::new(RlmProjectionExtension::new(bindings))
334}
335
/// Builder-style helpers for attaching RLM projections to a turn input.
pub trait RlmTurnInputExt {
    /// Adds `bindings` as read-only variables for this turn.
    ///
    /// # Errors
    /// For `TurnInput`, fails when a name in `bindings` is already projected on the input.
    fn rlm_project(self, bindings: RlmProjectedBindings) -> Result<Self, ProjectedBindingError>
    where
        Self: Sized;

    /// Registers `projector` for this turn's tool results.
    ///
    /// # Errors
    /// The `Result` mirrors `rlm_project`. For `TurnInput` the added extension has no
    /// bindings, so a duplicate-name error is not expected.
    fn rlm_project_tool_results(
        self,
        projector: RlmToolResultProjector,
    ) -> Result<Self, ProjectedBindingError>
    where
        Self: Sized;
}
348
349impl RlmTurnInputExt for TurnInput {
350 fn rlm_project(
351 mut self,
352 bindings: RlmProjectedBindings,
353 ) -> Result<Self, ProjectedBindingError> {
354 let extension = if let Some(existing) = self
355 .turn_context
356 .plugin_input::<RlmProjectionExtension>(RLM_TURN_INPUT_PLUGIN_ID)
357 .cloned()
358 {
359 existing
360 .clone()
361 .merge(RlmProjectionExtension::new(bindings))?
362 } else {
363 RlmProjectionExtension::new(bindings)
364 };
365 self.turn_context
366 .insert_plugin_input(RLM_TURN_INPUT_PLUGIN_ID, extension);
367 self.protocol_extension = Some(ProtocolTurnExtensionHandle::new(
368 RlmProjectionExtension::new(
369 self.turn_context
370 .plugin_input::<RlmProjectionExtension>(RLM_TURN_INPUT_PLUGIN_ID)
371 .expect("RLM projection was just inserted")
372 .bindings
373 .clone(),
374 ),
375 ));
376 Ok(self)
377 }
378
379 fn rlm_project_tool_results(
380 mut self,
381 projector: RlmToolResultProjector,
382 ) -> Result<Self, ProjectedBindingError> {
383 let extension = if let Some(existing) = self
384 .turn_context
385 .plugin_input::<RlmProjectionExtension>(RLM_TURN_INPUT_PLUGIN_ID)
386 .cloned()
387 {
388 existing
389 .clone()
390 .merge(RlmProjectionExtension::with_projector(projector))?
391 } else {
392 RlmProjectionExtension::with_projector(projector)
393 };
394 self.turn_context
395 .insert_plugin_input(RLM_TURN_INPUT_PLUGIN_ID, extension);
396 self.protocol_extension = Some(ProtocolTurnExtensionHandle::new(
397 RlmProjectionExtension::new(
398 self.turn_context
399 .plugin_input::<RlmProjectionExtension>(RLM_TURN_INPUT_PLUGIN_ID)
400 .expect("RLM projection was just inserted")
401 .bindings
402 .clone(),
403 ),
404 ));
405 Ok(self)
406 }
407}
408
#[cfg(test)]
mod tests {
    use super::*;
    use lashlang::{ProjectedFuture, ProjectedReadRequest, ProjectedReadResponse};

    /// Fixture host value: materializes to the string `"lazy"`, renders as `lazy`,
    /// and answers `Missing` to every other read request.
    struct TestProjectedValue;

    impl ProjectedHostValue for TestProjectedValue {
        fn type_name(&self) -> &str {
            "string"
        }

        fn read_one(
            &self,
            request: ProjectedReadRequest,
        ) -> ProjectedFuture<'_, ProjectedReadResponse> {
            Box::pin(async move {
                match request {
                    ProjectedReadRequest::Materialize => {
                        ProjectedReadResponse::Value(FlowValue::String("lazy".into()))
                    }
                    ProjectedReadRequest::Render => ProjectedReadResponse::Text("lazy".into()),
                    _ => ProjectedReadResponse::Missing,
                }
            })
        }
    }

    /// Binding the same name twice fails, and the error names the offending binding.
    #[test]
    fn bind_rejects_duplicate_names() {
        let duplicate = RlmProjectedBindings::new()
            .bind_json("current_query", serde_json::json!("first"))
            .expect("first bind")
            .bind_json("current_query", serde_json::json!("second"));
        let Err(err) = duplicate else {
            panic!("duplicate bind should fail");
        };
        assert_eq!(err.name(), "current_query");
    }

    /// Merging session and turn bindings that share a name is rejected.
    #[test]
    fn merge_rejects_session_turn_duplicates() {
        let session = RlmProjectedBindings::new()
            .bind_json("current_query", serde_json::json!("session"))
            .expect("session bind");
        let turn = RlmProjectedBindings::new()
            .bind_json("current_query", serde_json::json!("turn"))
            .expect("turn bind");
        let duplicate = session.merge(turn);
        let Err(err) = duplicate else {
            panic!("duplicate session and turn binding should fail");
        };
        assert_eq!(err.name(), "current_query");
    }

    /// A lazy binding to a registered memory ref resolves to the registered value
    /// and keeps the serialized ref as its projection ref.
    #[tokio::test]
    async fn bind_lazy_resolves_memory_projection_ref() {
        let registry = Arc::new(ProjectionRegistry::new());
        let reference = registry.register_memory(Arc::new(TestProjectedValue));
        let bindings = RlmProjectedBindings::new()
            .bind_lazy("doc", reference.clone())
            .expect("lazy bind");

        let projected = bindings
            .into_projected_bindings(registry)
            .await
            .expect("resolve projected bindings");
        let value = projected.get("doc").expect("doc binding");
        assert_eq!(value.projection_ref(), Some(&serde_json::json!(reference)));
        assert_eq!(value.render().await, "lazy");
    }

    /// An unregistered memory key surfaces as an "unavailable" resolve error.
    #[tokio::test]
    async fn bind_lazy_reports_missing_memory_projection_ref() {
        let registry = Arc::new(ProjectionRegistry::new());
        let reference = ProjectionRef::new("memory", serde_json::json!("missing"));
        let bindings = RlmProjectedBindings::new()
            .bind_lazy("doc", reference)
            .expect("lazy bind");

        let err = match bindings.into_projected_bindings(registry).await {
            Ok(_) => panic!("missing ref should fail"),
            Err(err) => err,
        };
        assert!(err.to_string().contains("projection ref unavailable"));
    }

    /// A snapshot entry tagged as a projection ref is restored under its name.
    #[test]
    fn projected_seed_snapshot_preserves_projection_refs() {
        let reference = ProjectionRef::new("memory", serde_json::json!("stable"));
        let mut snapshot = lash_rlm_types::RlmProjectedSeedSnapshot::new();
        snapshot.push(
            "doc",
            serde_json::json!({
                lash_rlm_types::PROJECTION_REF_JSON_TAG: reference,
            }),
        );

        let bindings = RlmProjectedBindings::from_snapshot(&snapshot).expect("snapshot");
        assert_eq!(
            bindings.names().collect::<Vec<_>>(),
            vec!["doc".to_string()]
        );
    }

    /// A tagged entry whose payload is not a valid projection ref is reported, and
    /// the message includes the seed name.
    #[test]
    fn projected_seed_snapshot_reports_invalid_projection_refs() {
        let mut snapshot = lash_rlm_types::RlmProjectedSeedSnapshot::new();
        snapshot.push(
            "doc",
            serde_json::json!({
                lash_rlm_types::PROJECTION_REF_JSON_TAG: "not a projection ref",
            }),
        );

        let err = match RlmProjectedBindings::from_snapshot(&snapshot) {
            Ok(_) => panic!("invalid projection ref should fail"),
            Err(err) => err,
        };

        assert!(err.to_string().contains("invalid projection ref"));
        assert!(err.to_string().contains("doc"));
    }

    /// `rlm_project` installs a protocol extension whose prompt contribution lists
    /// the bound names.
    #[test]
    fn turn_input_extension_attaches_prompt_contribution() {
        let input = TurnInput {
            items: Vec::new(),
            image_blobs: Default::default(),
            protocol_turn_options: None,
            trace_turn_id: None,
            protocol_extension: None,
            turn_context: lash_core::TurnContext::default(),
        }
        .rlm_project(
            RlmProjectedBindings::new()
                .bind_json("current_file", serde_json::json!("src/lib.rs"))
                .expect("bind"),
        )
        .expect("attach");
        let contribution = input
            .protocol_extension
            .expect("extension")
            .prompt_contributions()
            .pop()
            .expect("prompt contribution");
        assert!(contribution.content.contains("`current_file`"));
        assert!(contribution.content.contains("read-only value"));
    }

    /// Serializing a turn input drops the projection: neither the extension nor the
    /// bound names appear in the encoded form, and decoding yields no extension.
    #[test]
    fn turn_input_extension_is_skipped_by_serde() {
        let input = TurnInput {
            items: Vec::new(),
            image_blobs: Default::default(),
            protocol_turn_options: None,
            trace_turn_id: Some("stable".to_string()),
            protocol_extension: None,
            turn_context: lash_core::TurnContext::default(),
        }
        .rlm_project(
            RlmProjectedBindings::new()
                .bind_json("current_file", serde_json::json!("src/lib.rs"))
                .expect("bind"),
        )
        .expect("attach");

        let encoded = serde_json::to_string(&input).expect("serialize");
        assert!(!encoded.contains("protocol_extension"));
        assert!(!encoded.contains("current_file"));
        let decoded: TurnInput = serde_json::from_str(&encoded).expect("deserialize");
        assert!(decoded.protocol_extension.is_none());
        assert_eq!(decoded.trace_turn_id.as_deref(), Some("stable"));
    }

    /// Two inputs with the same trace id each keep only their own bindings.
    #[test]
    fn matching_trace_turn_ids_do_not_share_projection_extensions() {
        let first = TurnInput {
            items: Vec::new(),
            image_blobs: Default::default(),
            protocol_turn_options: None,
            trace_turn_id: Some("same-trace".to_string()),
            protocol_extension: None,
            turn_context: lash_core::TurnContext::default(),
        }
        .rlm_project(
            RlmProjectedBindings::new()
                .bind_json("first_name", serde_json::json!("first"))
                .expect("bind"),
        )
        .expect("attach first");
        let second = TurnInput {
            items: Vec::new(),
            image_blobs: Default::default(),
            protocol_turn_options: None,
            trace_turn_id: Some("same-trace".to_string()),
            protocol_extension: None,
            turn_context: lash_core::TurnContext::default(),
        }
        .rlm_project(
            RlmProjectedBindings::new()
                .bind_json("second_name", serde_json::json!("second"))
                .expect("bind"),
        )
        .expect("attach second");

        let first_extension = first
            .protocol_extension
            .as_ref()
            .and_then(|extension| extension.as_any().downcast_ref::<RlmProjectionExtension>())
            .expect("first extension");
        let second_extension = second
            .protocol_extension
            .as_ref()
            .and_then(|extension| extension.as_any().downcast_ref::<RlmProjectionExtension>())
            .expect("second extension");
        assert_eq!(
            first_extension.bindings.names().collect::<Vec<_>>(),
            vec!["first_name".to_string()]
        );
        assert_eq!(
            second_extension.bindings.names().collect::<Vec<_>>(),
            vec!["second_name".to_string()]
        );
    }
}