Skip to main content

lash_protocol_rlm/projection/
bindings.rs

1use std::any::Any;
2use std::collections::BTreeMap;
3use std::sync::Arc;
4
5use lash_core::{
6    PromptContribution, ProtocolSessionExtension, ProtocolTurnExtension,
7    ProtocolTurnExtensionHandle, TurnInput,
8};
9
/// Plugin-input key under which the merged [`RlmProjectionExtension`] is stored
/// in a turn's `TurnContext` (see the `RlmTurnInputExt` impl for `TurnInput`).
pub(crate) const RLM_TURN_INPUT_PLUGIN_ID: &str = "rlm";
11use lashlang::{
12    ProjectedBindingError, ProjectedBindings, ProjectedHostValue, ProjectedValue,
13    Value as FlowValue,
14};
15
16#[derive(Clone, Debug, PartialEq, serde::Serialize, serde::Deserialize)]
17pub struct ProjectionRef {
18    pub kind: String,
19    pub key: serde_json::Value,
20}
21
22impl ProjectionRef {
23    pub fn new(kind: impl Into<String>, key: serde_json::Value) -> Self {
24        Self {
25            kind: kind.into(),
26            key,
27        }
28    }
29}
30
31#[derive(Clone, Debug, PartialEq, Eq)]
32pub struct ProjectionResolveError {
33    message: String,
34}
35
36impl ProjectionResolveError {
37    pub fn unavailable(reference: &ProjectionRef) -> Self {
38        Self {
39            message: format!(
40                "projection ref unavailable: kind `{}`, key {}",
41                reference.kind, reference.key
42            ),
43        }
44    }
45
46    pub fn invalid(message: impl Into<String>) -> Self {
47        Self {
48            message: message.into(),
49        }
50    }
51}
52
53impl std::fmt::Display for ProjectionResolveError {
54    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
55        f.write_str(&self.message)
56    }
57}
58
59impl std::error::Error for ProjectionResolveError {}
60
/// Resolves a [`ProjectionRef`] into the host value it points at.
///
/// [`ProjectionRegistry`] implements this for `memory` refs.
/// `RlmProjectedBindings::into_projected_bindings` calls it once per lazy binding.
#[async_trait::async_trait]
pub trait ProjectionResolver: Send + Sync {
    /// Looks up the host value behind `reference`.
    ///
    /// # Errors
    /// Returns [`ProjectionResolveError`] when the reference cannot be resolved.
    async fn resolve_projection(
        &self,
        reference: &ProjectionRef,
    ) -> Result<Arc<dyn ProjectedHostValue>, ProjectionResolveError>;
}
68
69#[derive(Clone, Default)]
70pub struct ProjectionRegistry {
71    memory: Arc<std::sync::RwLock<BTreeMap<String, Arc<dyn ProjectedHostValue>>>>,
72}
73
74impl ProjectionRegistry {
75    pub fn new() -> Self {
76        Self::default()
77    }
78
79    pub fn register_memory(&self, value: Arc<dyn ProjectedHostValue>) -> ProjectionRef {
80        let key = uuid::Uuid::new_v4().to_string();
81        self.memory
82            .write()
83            .expect("projection registry lock")
84            .insert(key.clone(), value);
85        ProjectionRef::new("memory", serde_json::Value::String(key))
86    }
87}
88
89#[async_trait::async_trait]
90impl ProjectionResolver for ProjectionRegistry {
91    async fn resolve_projection(
92        &self,
93        reference: &ProjectionRef,
94    ) -> Result<Arc<dyn ProjectedHostValue>, ProjectionResolveError> {
95        if reference.kind != "memory" {
96            return Err(ProjectionResolveError::unavailable(reference));
97        }
98        let Some(key) = reference.key.as_str() else {
99            return Err(ProjectionResolveError::invalid(
100                "memory projection ref key must be a string",
101            ));
102        };
103        self.memory
104            .read()
105            .expect("projection registry lock")
106            .get(key)
107            .cloned()
108            .ok_or_else(|| ProjectionResolveError::unavailable(reference))
109    }
110}
111
/// One projected binding: either a value known up front or a reference
/// resolved later through a [`ProjectionResolver`].
#[derive(Clone)]
enum RlmProjectedBinding {
    /// Value bound directly (via `bind_value` / `bind_json`).
    Value(FlowValue),
    /// Reference resolved in `into_projected_bindings` (via `bind_lazy`).
    Lazy(ProjectionRef),
}

/// Named set of read-only bindings exposed to RLM `lashlang` code.
///
/// Names are unique. Every insertion path (`bind_value`, `bind_lazy`, `merge`)
/// rejects a repeated name with `ProjectedBindingError::duplicate`.
#[derive(Clone, Default)]
pub struct RlmProjectedBindings {
    // BTreeMap keeps names sorted, so `names()` iterates in a deterministic order.
    bindings: BTreeMap<String, RlmProjectedBinding>,
}
122
/// Shared callback that may turn a tool result into a lashlang value.
///
/// NOTE(review): the arguments are presumably (tool name, tool result JSON), and
/// `None` presumably means "not projected". Confirm against the caller that invokes
/// `RlmProjectionExtension::tool_result_projectors`.
pub type RlmToolResultProjector =
    Arc<dyn Fn(&str, &serde_json::Value) -> Option<FlowValue> + Send + Sync + 'static>;
125
126#[derive(Clone, Debug, PartialEq, Eq)]
127pub enum RlmProjectedSeedError {
128    Binding(ProjectedBindingError),
129    InvalidProjectionRef { name: String, source: String },
130}
131
132impl RlmProjectedSeedError {
133    pub fn invalid_projection_ref(name: impl Into<String>, source: impl std::fmt::Display) -> Self {
134        Self::InvalidProjectionRef {
135            name: name.into(),
136            source: source.to_string(),
137        }
138    }
139}
140
141impl From<ProjectedBindingError> for RlmProjectedSeedError {
142    fn from(value: ProjectedBindingError) -> Self {
143        Self::Binding(value)
144    }
145}
146
147impl std::fmt::Display for RlmProjectedSeedError {
148    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
149        match self {
150            Self::Binding(err) => err.fmt(f),
151            Self::InvalidProjectionRef { name, source } => {
152                write!(
153                    f,
154                    "invalid projection ref for projected seed `{name}`: {source}"
155                )
156            }
157        }
158    }
159}
160
161impl std::error::Error for RlmProjectedSeedError {}
162
163impl RlmProjectedBindings {
164    pub fn new() -> Self {
165        Self::default()
166    }
167
168    pub fn bind_value(
169        mut self,
170        name: impl Into<String>,
171        value: impl Into<FlowValue>,
172    ) -> Result<Self, ProjectedBindingError> {
173        let name = name.into();
174        if self.bindings.contains_key(&name) {
175            return Err(ProjectedBindingError::duplicate(name));
176        }
177        self.bindings
178            .insert(name, RlmProjectedBinding::Value(value.into()));
179        Ok(self)
180    }
181
182    pub fn bind_json(
183        self,
184        name: impl Into<String>,
185        value: serde_json::Value,
186    ) -> Result<Self, ProjectedBindingError> {
187        self.bind_value(name, lashlang::from_json(value))
188    }
189
190    pub fn bind_lazy(
191        mut self,
192        name: impl Into<String>,
193        reference: ProjectionRef,
194    ) -> Result<Self, ProjectedBindingError> {
195        let name = name.into();
196        if self.bindings.contains_key(&name) {
197            return Err(ProjectedBindingError::duplicate(name));
198        }
199        self.bindings
200            .insert(name, RlmProjectedBinding::Lazy(reference));
201        Ok(self)
202    }
203
204    pub fn names(&self) -> impl Iterator<Item = String> + '_ {
205        self.bindings.keys().cloned()
206    }
207
208    pub(crate) async fn into_projected_bindings(
209        self,
210        resolver: Arc<dyn ProjectionResolver>,
211    ) -> Result<ProjectedBindings, ProjectionResolveError> {
212        let mut out = ProjectedBindings::new();
213        for (name, binding) in self.bindings {
214            let value = match binding {
215                RlmProjectedBinding::Value(value) => ProjectedValue::scalar(name.clone(), value),
216                RlmProjectedBinding::Lazy(reference) => {
217                    let resolved = resolver.resolve_projection(&reference).await?;
218                    let ref_json = serde_json::to_value(&reference).map_err(|err| {
219                        ProjectionResolveError::invalid(format!(
220                            "projection ref did not serialize: {err}"
221                        ))
222                    })?;
223                    ProjectedValue::custom_with_projection_ref(name.clone(), resolved, ref_json)
224                }
225            };
226            out.try_insert(name, value)
227                .expect("RLM projected bindings already reject duplicates");
228        }
229        Ok(out)
230    }
231
232    pub fn merge(mut self, other: Self) -> Result<Self, ProjectedBindingError> {
233        for (name, value) in other.bindings {
234            if self.bindings.contains_key(&name) {
235                return Err(ProjectedBindingError::duplicate(name));
236            }
237            self.bindings.insert(name, value);
238        }
239        Ok(self)
240    }
241
242    /// Hydrate from a wire-format `RlmProjectedSeedSnapshot`. Each entry is
243    /// re-projected via `bind_json`. Used by the RLM protocol to seed projections on a
244    /// child session (spawn_agent / continue_as) from the parent's classified
245    /// seed map.
246    pub fn from_snapshot(
247        snapshot: &lash_rlm_types::RlmProjectedSeedSnapshot,
248    ) -> Result<Self, RlmProjectedSeedError> {
249        let mut out = Self::new();
250        for (name, value) in &snapshot.entries {
251            out = if let Some(reference) =
252                super::transport::projection_ref_from_seed_value(name, value)?
253            {
254                out.bind_lazy(name.clone(), reference)?
255            } else {
256                out.bind_json(name.clone(), value.clone())?
257            };
258        }
259        Ok(out)
260    }
261}
262
263#[derive(Clone, Default)]
264pub(crate) struct RlmProjectionExtension {
265    pub(crate) bindings: RlmProjectedBindings,
266    pub(crate) tool_result_projectors: Vec<RlmToolResultProjector>,
267}
268
269impl RlmProjectionExtension {
270    pub(crate) fn new(bindings: RlmProjectedBindings) -> Self {
271        Self {
272            bindings,
273            tool_result_projectors: Vec::new(),
274        }
275    }
276
277    pub(crate) fn with_projector(projector: RlmToolResultProjector) -> Self {
278        Self {
279            bindings: RlmProjectedBindings::new(),
280            tool_result_projectors: vec![projector],
281        }
282    }
283
284    fn merge(mut self, other: Self) -> Result<Self, ProjectedBindingError> {
285        self.bindings = self.bindings.merge(other.bindings)?;
286        self.tool_result_projectors
287            .extend(other.tool_result_projectors);
288        Ok(self)
289    }
290
291    pub(crate) fn prompt_contributions_for(
292        bindings: &RlmProjectedBindings,
293    ) -> Vec<PromptContribution> {
294        let mut names = bindings.names().collect::<Vec<_>>();
295        if names.is_empty() {
296            return Vec::new();
297        }
298        names.sort();
299        let mut lines = vec![
300            "These read-only values are already in scope. Access them directly in fenced `lashlang` code; do not recreate them manually.".to_string(),
301            String::new(),
302            "Read-only variables:".to_string(),
303        ];
304        for name in names {
305            lines.push(format!("- `{name}`: read-only value"));
306        }
307        vec![PromptContribution::environment(
308            "Read-Only Variables",
309            lines.join("\n"),
310        )]
311    }
312}
313
314impl ProtocolTurnExtension for RlmProjectionExtension {
315    fn as_any(&self) -> &dyn Any {
316        self
317    }
318
319    fn prompt_contributions(&self) -> Vec<PromptContribution> {
320        Self::prompt_contributions_for(&self.bindings)
321    }
322}
323
324impl ProtocolSessionExtension for RlmProjectionExtension {
325    fn as_any(&self) -> &dyn Any {
326        self
327    }
328}
329
330pub fn rlm_session_projection_extension(
331    bindings: RlmProjectedBindings,
332) -> lash_core::ProtocolSessionExtensionHandle {
333    lash_core::ProtocolSessionExtensionHandle::new(RlmProjectionExtension::new(bindings))
334}
335
/// Builder-style helpers for attaching RLM projections to a turn input.
pub trait RlmTurnInputExt {
    /// Adds `bindings` to the turn's RLM projection.
    ///
    /// # Errors
    /// Returns [`ProjectedBindingError`] if a name is already projected for this turn.
    fn rlm_project(self, bindings: RlmProjectedBindings) -> Result<Self, ProjectedBindingError>
    where
        Self: Sized;

    /// Adds a tool-result projector to the turn's RLM projection.
    ///
    /// # Errors
    /// Declared fallible for symmetry with `rlm_project`, since both merge into the
    /// same extension.
    fn rlm_project_tool_results(
        self,
        projector: RlmToolResultProjector,
    ) -> Result<Self, ProjectedBindingError>
    where
        Self: Sized;
}
348
349impl RlmTurnInputExt for TurnInput {
350    fn rlm_project(
351        mut self,
352        bindings: RlmProjectedBindings,
353    ) -> Result<Self, ProjectedBindingError> {
354        let extension = if let Some(existing) = self
355            .turn_context
356            .plugin_input::<RlmProjectionExtension>(RLM_TURN_INPUT_PLUGIN_ID)
357            .cloned()
358        {
359            existing
360                .clone()
361                .merge(RlmProjectionExtension::new(bindings))?
362        } else {
363            RlmProjectionExtension::new(bindings)
364        };
365        self.turn_context
366            .insert_plugin_input(RLM_TURN_INPUT_PLUGIN_ID, extension);
367        self.protocol_extension = Some(ProtocolTurnExtensionHandle::new(
368            RlmProjectionExtension::new(
369                self.turn_context
370                    .plugin_input::<RlmProjectionExtension>(RLM_TURN_INPUT_PLUGIN_ID)
371                    .expect("RLM projection was just inserted")
372                    .bindings
373                    .clone(),
374            ),
375        ));
376        Ok(self)
377    }
378
379    fn rlm_project_tool_results(
380        mut self,
381        projector: RlmToolResultProjector,
382    ) -> Result<Self, ProjectedBindingError> {
383        let extension = if let Some(existing) = self
384            .turn_context
385            .plugin_input::<RlmProjectionExtension>(RLM_TURN_INPUT_PLUGIN_ID)
386            .cloned()
387        {
388            existing
389                .clone()
390                .merge(RlmProjectionExtension::with_projector(projector))?
391        } else {
392            RlmProjectionExtension::with_projector(projector)
393        };
394        self.turn_context
395            .insert_plugin_input(RLM_TURN_INPUT_PLUGIN_ID, extension);
396        self.protocol_extension = Some(ProtocolTurnExtensionHandle::new(
397            RlmProjectionExtension::new(
398                self.turn_context
399                    .plugin_input::<RlmProjectionExtension>(RLM_TURN_INPUT_PLUGIN_ID)
400                    .expect("RLM projection was just inserted")
401                    .bindings
402                    .clone(),
403            ),
404        ));
405        Ok(self)
406    }
407}
408
#[cfg(test)]
mod tests {
    use super::*;
    use lashlang::{ProjectedFuture, ProjectedReadRequest, ProjectedReadResponse};

    /// Host value that materializes and renders as the string `"lazy"`.
    struct TestProjectedValue;

    impl ProjectedHostValue for TestProjectedValue {
        fn type_name(&self) -> &str {
            "string"
        }

        fn read_one(
            &self,
            request: ProjectedReadRequest,
        ) -> ProjectedFuture<'_, ProjectedReadResponse> {
            Box::pin(async move {
                match request {
                    ProjectedReadRequest::Materialize => {
                        ProjectedReadResponse::Value(FlowValue::String("lazy".into()))
                    }
                    ProjectedReadRequest::Render => ProjectedReadResponse::Text("lazy".into()),
                    _ => ProjectedReadResponse::Missing,
                }
            })
        }
    }

    /// A `TurnInput` with no items, no extension and an empty turn context.
    fn empty_turn_input(trace_turn_id: Option<&str>) -> TurnInput {
        TurnInput {
            items: Vec::new(),
            image_blobs: Default::default(),
            protocol_turn_options: None,
            trace_turn_id: trace_turn_id.map(str::to_string),
            protocol_extension: None,
            turn_context: lash_core::TurnContext::default(),
        }
    }

    /// Bindings holding a single JSON string value under `name`.
    fn json_binding(name: &str, value: &str) -> RlmProjectedBindings {
        RlmProjectedBindings::new()
            .bind_json(name, serde_json::json!(value))
            .expect("bind")
    }

    /// Downcasts the turn's protocol extension to the RLM projection extension.
    fn attached_extension(input: &TurnInput) -> &RlmProjectionExtension {
        input
            .protocol_extension
            .as_ref()
            .and_then(|extension| extension.as_any().downcast_ref::<RlmProjectionExtension>())
            .expect("rlm protocol extension")
    }

    #[test]
    fn bind_rejects_duplicate_names() {
        let duplicate = RlmProjectedBindings::new()
            .bind_json("current_query", serde_json::json!("first"))
            .expect("first bind")
            .bind_json("current_query", serde_json::json!("second"));
        let Err(err) = duplicate else {
            panic!("duplicate bind should fail");
        };
        assert_eq!(err.name(), "current_query");
    }

    #[test]
    fn merge_rejects_session_turn_duplicates() {
        let session = json_binding("current_query", "session");
        let turn = json_binding("current_query", "turn");
        let duplicate = session.merge(turn);
        let Err(err) = duplicate else {
            panic!("duplicate session and turn binding should fail");
        };
        assert_eq!(err.name(), "current_query");
    }

    #[tokio::test]
    async fn bind_lazy_resolves_memory_projection_ref() {
        let registry = Arc::new(ProjectionRegistry::new());
        let reference = registry.register_memory(Arc::new(TestProjectedValue));
        let bindings = RlmProjectedBindings::new()
            .bind_lazy("doc", reference.clone())
            .expect("lazy bind");

        let projected = bindings
            .into_projected_bindings(registry)
            .await
            .expect("resolve projected bindings");
        let value = projected.get("doc").expect("doc binding");
        assert_eq!(value.projection_ref(), Some(&serde_json::json!(reference)));
        assert_eq!(value.render().await, "lazy");
    }

    #[tokio::test]
    async fn bind_lazy_reports_missing_memory_projection_ref() {
        let registry = Arc::new(ProjectionRegistry::new());
        let reference = ProjectionRef::new("memory", serde_json::json!("missing"));
        let bindings = RlmProjectedBindings::new()
            .bind_lazy("doc", reference)
            .expect("lazy bind");

        let err = match bindings.into_projected_bindings(registry).await {
            Ok(_) => panic!("missing ref should fail"),
            Err(err) => err,
        };
        assert!(err.to_string().contains("projection ref unavailable"));
    }

    #[test]
    fn projected_seed_snapshot_preserves_projection_refs() {
        let reference = ProjectionRef::new("memory", serde_json::json!("stable"));
        let mut snapshot = lash_rlm_types::RlmProjectedSeedSnapshot::new();
        snapshot.push(
            "doc",
            serde_json::json!({
                lash_rlm_types::PROJECTION_REF_JSON_TAG: reference,
            }),
        );

        let bindings = RlmProjectedBindings::from_snapshot(&snapshot).expect("snapshot");
        assert_eq!(
            bindings.names().collect::<Vec<_>>(),
            vec!["doc".to_string()]
        );
    }

    #[test]
    fn projected_seed_snapshot_reports_invalid_projection_refs() {
        let mut snapshot = lash_rlm_types::RlmProjectedSeedSnapshot::new();
        snapshot.push(
            "doc",
            serde_json::json!({
                lash_rlm_types::PROJECTION_REF_JSON_TAG: "not a projection ref",
            }),
        );

        let err = match RlmProjectedBindings::from_snapshot(&snapshot) {
            Ok(_) => panic!("invalid projection ref should fail"),
            Err(err) => err,
        };

        assert!(err.to_string().contains("invalid projection ref"));
        assert!(err.to_string().contains("doc"));
    }

    #[test]
    fn turn_input_extension_attaches_prompt_contribution() {
        let input = empty_turn_input(None)
            .rlm_project(json_binding("current_file", "src/lib.rs"))
            .expect("attach");
        let contribution = input
            .protocol_extension
            .expect("extension")
            .prompt_contributions()
            .pop()
            .expect("prompt contribution");
        assert!(contribution.content.contains("`current_file`"));
        assert!(contribution.content.contains("read-only value"));
    }

    #[test]
    fn turn_input_extension_is_skipped_by_serde() {
        let input = empty_turn_input(Some("stable"))
            .rlm_project(json_binding("current_file", "src/lib.rs"))
            .expect("attach");

        let encoded = serde_json::to_string(&input).expect("serialize");
        assert!(!encoded.contains("protocol_extension"));
        assert!(!encoded.contains("current_file"));
        let decoded: TurnInput = serde_json::from_str(&encoded).expect("deserialize");
        assert!(decoded.protocol_extension.is_none());
        assert_eq!(decoded.trace_turn_id.as_deref(), Some("stable"));
    }

    #[test]
    fn matching_trace_turn_ids_do_not_share_projection_extensions() {
        let first = empty_turn_input(Some("same-trace"))
            .rlm_project(json_binding("first_name", "first"))
            .expect("attach first");
        let second = empty_turn_input(Some("same-trace"))
            .rlm_project(json_binding("second_name", "second"))
            .expect("attach second");

        assert_eq!(
            attached_extension(&first).bindings.names().collect::<Vec<_>>(),
            vec!["first_name".to_string()]
        );
        assert_eq!(
            attached_extension(&second).bindings.names().collect::<Vec<_>>(),
            vec!["second_name".to_string()]
        );
    }

    /// Repeated attachment must accumulate bindings and projectors in the stored
    /// plugin input. The protocol extension must expose the merged bindings.
    #[test]
    fn repeated_projection_merges_bindings_and_projectors() {
        let projector: RlmToolResultProjector =
            Arc::new(|_tool: &str, _result: &serde_json::Value| -> Option<FlowValue> { None });
        let input = empty_turn_input(None)
            .rlm_project(json_binding("first_name", "first"))
            .expect("attach first")
            .rlm_project_tool_results(projector)
            .expect("attach projector")
            .rlm_project(json_binding("second_name", "second"))
            .expect("attach second");

        let expected = vec!["first_name".to_string(), "second_name".to_string()];
        let stored = input
            .turn_context
            .plugin_input::<RlmProjectionExtension>(RLM_TURN_INPUT_PLUGIN_ID)
            .expect("stored extension");
        assert_eq!(stored.bindings.names().collect::<Vec<_>>(), expected);
        assert_eq!(stored.tool_result_projectors.len(), 1);
        assert_eq!(
            attached_extension(&input).bindings.names().collect::<Vec<_>>(),
            expected
        );
    }

    /// Two projections onto the same turn may not bind the same name.
    #[test]
    fn repeated_projection_rejects_duplicate_turn_bindings() {
        let input = empty_turn_input(None)
            .rlm_project(json_binding("current_query", "first"))
            .expect("attach first");
        let duplicate = input.rlm_project(json_binding("current_query", "second"));
        let Err(err) = duplicate else {
            panic!("duplicate turn binding should fail");
        };
        assert_eq!(err.name(), "current_query");
    }
}