Skip to main content

lash_protocol_rlm/projection/
context.rs

1use std::collections::BTreeSet;
2use std::sync::Arc;
3
4use lash_core::{ChronologicalPayload, Message, MessageRole, PartKind, RuntimeExecutionContext};
5use lash_rlm_types::{
6    RlmAttachmentRef, RlmHistoryItem, RlmHistoryRole, RlmImageRef, RlmProtocolEvent,
7    RlmTrajectoryEntry,
8};
9use lashlang::{
10    ProjectedBindings, ProjectedFuture, ProjectedHostValue, ProjectedReadRequest,
11    ProjectedReadResponse, ProjectedValue, State as FlowState, Value as FlowValue,
12};
13
14use super::bindings::{
15    ProjectionResolver, RLM_TURN_INPUT_PLUGIN_ID, RlmProjectedBindings, RlmProjectionExtension,
16};
17use super::transport::json_to_flow_value;
18
19pub fn rlm_protocol_event(event: RlmProtocolEvent) -> lash_core::ProtocolEvent {
20    lash_core::ProtocolEvent::typed(crate::plugin::RLM_PROTOCOL_PLUGIN_ID, event)
21        .expect("RLM protocol events serialize")
22}
23
24pub fn decode_rlm_protocol_event(event: &lash_core::ProtocolEvent) -> Option<RlmProtocolEvent> {
25    event
26        .decode(crate::plugin::RLM_PROTOCOL_PLUGIN_ID)
27        .ok()
28        .flatten()
29}
30
31#[derive(Clone, Debug)]
32pub struct RlmHistoryProjection {
33    history: Vec<RlmHistoryItem>,
34}
35
36impl RlmHistoryProjection {
37    pub fn from_chronological(projection: &lash_core::ChronologicalProjection) -> Self {
38        let mut history = Vec::with_capacity(projection.entries().len());
39        for entry in projection.entries() {
40            match &entry.payload {
41                ChronologicalPayload::Message(message) => {
42                    if let Some(item) = history_item_from_message(message) {
43                        history.push(item);
44                    }
45                }
46                ChronologicalPayload::ProtocolEvent(event) => {
47                    if let Some(RlmProtocolEvent::RlmTrajectoryEntry(step)) =
48                        decode_rlm_protocol_event(event)
49                    {
50                        history.push(history_item_from_rlm_step(&step));
51                    }
52                }
53            }
54        }
55        Self { history }
56    }
57
58    pub fn history(&self) -> &[RlmHistoryItem] {
59        self.history.as_slice()
60    }
61
62    pub fn len(&self) -> usize {
63        self.history.len()
64    }
65
66    pub fn is_empty(&self) -> bool {
67        self.history.is_empty()
68    }
69
70    pub fn item(&self, index: usize) -> Option<RlmHistoryItem> {
71        self.history.get(index).cloned()
72    }
73
74    pub fn value(&self) -> serde_json::Value {
75        serde_json::to_value(&self.history).unwrap_or_else(|_| serde_json::Value::Array(vec![]))
76    }
77}
78
79pub fn rlm_history_projection(
80    projection: &lash_core::ChronologicalProjection,
81) -> RlmHistoryProjection {
82    RlmHistoryProjection::from_chronological(projection)
83}
84
85pub(crate) async fn projected_bindings(
86    ctx: &RuntimeExecutionContext<'_>,
87    session_bindings: RlmProjectedBindings,
88    projection_resolver: Arc<dyn ProjectionResolver>,
89) -> Result<ProjectedBindings, String> {
90    let mut bindings = ProjectedBindings::new();
91    bindings
92        .try_insert(
93            "history",
94            ProjectedValue::custom(
95                "history",
96                Arc::new(HistoryProjectedValue {
97                    projection: Arc::new(rlm_history_projection(
98                        ctx.chronological_projection().as_ref(),
99                    )),
100                }),
101            ),
102        )
103        .map_err(|err| format!("`{}` is reserved as an RLM built-in binding", err.name()))?;
104    insert_projected_bindings(
105        &mut bindings,
106        session_bindings,
107        Arc::clone(&projection_resolver),
108    )
109    .await?;
110    if let Some(extension) = ctx
111        .turn_context()
112        .plugin_input::<RlmProjectionExtension>(RLM_TURN_INPUT_PLUGIN_ID)
113    {
114        insert_projected_bindings(
115            &mut bindings,
116            extension.bindings.clone(),
117            projection_resolver,
118        )
119        .await?;
120    }
121    Ok(bindings)
122}
123
124async fn insert_projected_bindings(
125    target: &mut ProjectedBindings,
126    bindings: RlmProjectedBindings,
127    projection_resolver: Arc<dyn ProjectionResolver>,
128) -> Result<(), String> {
129    let host_bindings = bindings
130        .into_projected_bindings(projection_resolver)
131        .await
132        .map_err(|err| err.to_string())?;
133    for name in host_bindings.names().collect::<Vec<_>>() {
134        let value = host_bindings
135            .get(&name)
136            .expect("name came from projected bindings");
137        target.try_insert(name, value).map_err(|err| {
138            format!(
139                "`{}` is already bound as an RLM projected binding",
140                err.name()
141            )
142        })?;
143    }
144    Ok(())
145}
146
147struct HistoryProjectedValue {
148    projection: Arc<RlmHistoryProjection>,
149}
150
151impl ProjectedHostValue for HistoryProjectedValue {
152    fn type_name(&self) -> &str {
153        "list"
154    }
155
156    fn read_one(
157        &self,
158        request: ProjectedReadRequest,
159    ) -> ProjectedFuture<'_, ProjectedReadResponse> {
160        Box::pin(async move {
161            match request {
162                ProjectedReadRequest::Len => ProjectedReadResponse::Len(self.projection.len()),
163                ProjectedReadRequest::Index(index) => {
164                    let Ok(Some(index)) = projected_index(&index, self.projection.len()) else {
165                        return ProjectedReadResponse::Missing;
166                    };
167                    self.projection
168                        .item(index)
169                        .and_then(|item| serde_json::to_value(item).ok())
170                        .map(json_to_flow_value)
171                        .map(ProjectedReadResponse::Value)
172                        .unwrap_or(ProjectedReadResponse::Missing)
173                }
174                ProjectedReadRequest::Render => ProjectedReadResponse::Text(
175                    serde_json::to_string(self.projection.history())
176                        .unwrap_or_else(|_| "[]".to_string()),
177                ),
178                ProjectedReadRequest::Materialize => {
179                    ProjectedReadResponse::Value(json_to_flow_value(self.projection.value()))
180                }
181                _ => ProjectedReadResponse::Missing,
182            }
183        })
184    }
185}
186
187pub(crate) fn projected_index(index: &FlowValue, len: usize) -> Result<Option<usize>, ()> {
188    let FlowValue::Number(index) = index else {
189        return Err(());
190    };
191    if !index.is_finite() || index.fract() != 0.0 {
192        return Err(());
193    }
194    let len = len as isize;
195    let index = *index as isize;
196    let normalized = if index < 0 { len + index } else { index };
197    if normalized < 0 || normalized >= len {
198        return Ok(None);
199    }
200    Ok(Some(normalized as usize))
201}
202
203pub(crate) fn prune_reserved_projected_bindings(rlm: &mut FlowState) {
204    prune_protected_bindings(rlm, &BTreeSet::new());
205}
206
207pub(crate) fn prune_protected_bindings(rlm: &mut FlowState, protected_names: &BTreeSet<String>) {
208    prune_projected_binding_names(
209        rlm,
210        std::iter::once("history").chain(protected_names.iter().map(String::as_str)),
211    );
212}
213
214pub(crate) fn prune_projected_binding_names<'a>(
215    rlm: &mut FlowState,
216    names: impl IntoIterator<Item = &'a str>,
217) {
218    let mut snapshot = rlm.snapshot();
219    for key in names {
220        snapshot.globals.remove(key);
221    }
222    *rlm = FlowState::from_snapshot(snapshot);
223}
224
225fn history_item_from_message(message: &Message) -> Option<RlmHistoryItem> {
226    Some(RlmHistoryItem::Message {
227        id: message.id.clone(),
228        role: history_role(message.role),
229        content: message_history_text(message),
230        attachments: message
231            .parts
232            .iter()
233            .filter_map(|part| {
234                let attachment = part.attachment.as_ref()?;
235                Some(RlmAttachmentRef {
236                    id: part.id.clone(),
237                    media_type: attachment.reference.media_type,
238                    label: attachment.reference.label.clone(),
239                    reference: attachment.reference.id.to_string(),
240                })
241            })
242            .collect(),
243    })
244}
245
246fn history_item_from_rlm_step(entry: &RlmTrajectoryEntry) -> RlmHistoryItem {
247    RlmHistoryItem::RlmStep {
248        id: entry.id.clone(),
249        protocol_iteration: entry.protocol_iteration,
250        reasoning: entry.reasoning.clone(),
251        code: entry.code.clone(),
252        output: entry.output.clone(),
253        images: entry.images.iter().map(image_ref).collect(),
254        error: entry.error.clone(),
255        final_output: entry.final_output.clone(),
256    }
257}
258
259fn message_history_text(message: &Message) -> String {
260    let chunks = message
261        .parts
262        .iter()
263        .filter(|part| matches!(part.kind, PartKind::Text | PartKind::Prose))
264        .map(|part| part.content.trim())
265        .filter(|part| !part.is_empty())
266        .collect::<Vec<_>>();
267    chunks.join("\n\n")
268}
269
270fn history_role(role: MessageRole) -> RlmHistoryRole {
271    match role {
272        MessageRole::User => RlmHistoryRole::User,
273        MessageRole::System => RlmHistoryRole::System,
274        MessageRole::Assistant => RlmHistoryRole::Assistant,
275        MessageRole::Event => RlmHistoryRole::Event,
276    }
277}
278
279fn image_ref(image: &lash_core::AttachmentRef) -> RlmImageRef {
280    RlmImageRef {
281        id: image.id.to_string(),
282        media_type: image.media_type,
283        width: image.width,
284        height: image.height,
285        bytes: image.byte_len as usize,
286        label: image.label.clone(),
287    }
288}
289
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a chronological projection containing a single RLM trajectory
    /// step whose only output chunk is `output`.
    fn step_projection(output: &str) -> lash_core::ChronologicalProjection {
        let entry = RlmTrajectoryEntry {
            id: "rlm_step_0".to_string(),
            protocol_iteration: 0,
            reasoning: "thinking".to_string(),
            code: "print big".to_string(),
            output: vec![output.to_string()],
            images: Vec::new(),
            error: None,
            final_output: None,
        };
        let events = [lash_core::SessionEventRecord::Protocol(rlm_protocol_event(
            RlmProtocolEvent::RlmTrajectoryEntry(entry),
        ))];
        lash_core::ChronologicalProjection::from_turn_view(
            &events,
            &lash_core::MessageSequence::default(),
        )
    }

    /// Reads `history[index]` and panics unless a value comes back.
    async fn read_index(value: &HistoryProjectedValue, index: i64) -> FlowValue {
        match value
            .read_one(ProjectedReadRequest::Index(FlowValue::Number(index as f64)))
            .await
        {
            ProjectedReadResponse::Value(value) => value,
            other => panic!("expected indexed value, got {other:?}"),
        }
    }

    // The history renderer advertises a `full: history[N].output[M]` reference
    // for truncated outputs (see `truncated_ref` in driver/history.rs). This
    // proves the contract resolves: indexing the real `history` projection
    // hands back the FULL, untruncated value the prompt only previewed.
    #[tokio::test]
    async fn history_step_output_resolves_full_untruncated_value() {
        let full = "X".repeat(50_000);
        let projection = step_projection(&full);
        let value = HistoryProjectedValue {
            projection: Arc::new(rlm_history_projection(&projection)),
        };

        // history[0] -> the serialized RLM step record.
        let FlowValue::Record(step) = read_index(&value, 0).await else {
            panic!("history[0] should be a record");
        };
        // history[0].output -> the list of per-print outputs.
        let Some(FlowValue::List(outputs)) = step.get("output") else {
            panic!("step record should carry an `output` list, got {step:?}");
        };
        // history[0].output[0] -> the full untruncated string.
        let Some(FlowValue::String(text)) = outputs.first() else {
            panic!("output[0] should be a string");
        };
        assert_eq!(
            text.as_str(),
            full.as_str(),
            "re-fetched value must be the full untruncated output"
        );
    }

    // Positive indices address from the front, negative ones from the back,
    // and anything outside `0..len` after normalization is `Ok(None)`.
    #[test]
    fn projected_index_normalizes_negative_and_bounds_checks() {
        assert_eq!(projected_index(&FlowValue::Number(0.0), 3), Ok(Some(0)));
        assert_eq!(projected_index(&FlowValue::Number(2.0), 3), Ok(Some(2)));
        assert_eq!(projected_index(&FlowValue::Number(-1.0), 3), Ok(Some(2)));
        assert_eq!(projected_index(&FlowValue::Number(-3.0), 3), Ok(Some(0)));
        assert_eq!(projected_index(&FlowValue::Number(3.0), 3), Ok(None));
        assert_eq!(projected_index(&FlowValue::Number(-4.0), 3), Ok(None));
        assert_eq!(projected_index(&FlowValue::Number(0.0), 0), Ok(None));
        assert_eq!(projected_index(&FlowValue::Number(1e300), 3), Ok(None));
        assert_eq!(projected_index(&FlowValue::Number(-1e300), 3), Ok(None));
    }

    // Fractional and non-finite numbers are not valid indices at all.
    #[test]
    fn projected_index_rejects_non_integral_numbers() {
        assert_eq!(projected_index(&FlowValue::Number(0.5), 3), Err(()));
        assert_eq!(projected_index(&FlowValue::Number(f64::NAN), 3), Err(()));
        assert_eq!(projected_index(&FlowValue::Number(f64::INFINITY), 3), Err(()));
    }

    // `Len` reports the projection's size and indexing one past the end is
    // reported as `Missing` rather than a value.
    #[tokio::test]
    async fn history_len_and_out_of_range_index() {
        let projection = step_projection("hello");
        let value = HistoryProjectedValue {
            projection: Arc::new(rlm_history_projection(&projection)),
        };
        let len = value.projection.len();
        assert!(len >= 1, "projection should contain the recorded step");

        match value.read_one(ProjectedReadRequest::Len).await {
            ProjectedReadResponse::Len(reported) => assert_eq!(reported, len),
            other => panic!("expected length, got {other:?}"),
        }

        let past_end = value
            .read_one(ProjectedReadRequest::Index(FlowValue::Number(len as f64)))
            .await;
        assert!(
            matches!(past_end, ProjectedReadResponse::Missing),
            "history[len] should be missing, got {past_end:?}"
        );
    }
}