Skip to main content

lash_core/session_model/
mod.rs

1pub mod context;
2pub use lash_sansio::session_model::message;
3pub use lash_sansio::session_model::prompt;
4
5use std::sync::Arc;
6use tokio::sync::mpsc;
7
8use crate::ModelSpec;
9use crate::llm::types::{LlmEventSender, LlmStreamEvent};
10use crate::plugin::PluginMessage;
11use crate::provider::{ProviderBinding, ProviderHandle, ProviderResolutionError};
12
13pub use lash_sansio::format_tool_output_content;
14pub use lash_sansio::session_model::{
15    ConversationRecord, ErrorEnvelope, MAIN_AGENT_INTRO, Message, MessageRole, Part, PartKind,
16    PromptBuiltin, PromptSlot, PromptTemplate, PromptTemplateEntry, PromptTemplateSection,
17    PruneState, SessionEvent, TokenUsage, TurnTerminationPolicyState, default_prompt_template,
18    make_error_envelope, make_error_event, reassign_part_ids, render_prompt,
19    render_transcript_prompt, shared_parts,
20};
21
22pub fn fresh_message_id() -> String {
23    format!("m{}", uuid::Uuid::new_v4().simple())
24}
25
26#[derive(Clone, Debug, PartialEq)]
27pub struct ProtocolEvent {
28    pub plugin_id: String,
29    pub payload: serde_json::Value,
30}
31
32impl ProtocolEvent {
33    pub fn typed<T>(plugin_id: impl Into<String>, event: T) -> Result<Self, serde_json::Error>
34    where
35        T: serde::Serialize,
36    {
37        Ok(Self {
38            plugin_id: plugin_id.into(),
39            payload: serde_json::to_value(event)?,
40        })
41    }
42
43    pub fn decode<T>(&self, expected_plugin_id: &str) -> Result<Option<T>, serde_json::Error>
44    where
45        T: for<'de> serde::Deserialize<'de>,
46    {
47        if self.plugin_id != expected_plugin_id {
48            return Ok(None);
49        }
50        serde_json::from_value(self.payload.clone()).map(Some)
51    }
52}
53
54impl serde::Serialize for ProtocolEvent {
55    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
56    where
57        S: serde::Serializer,
58    {
59        #[derive(serde::Serialize)]
60        struct Tagged<'a> {
61            plugin_id: &'a str,
62            payload: &'a serde_json::Value,
63        }
64        Tagged {
65            plugin_id: &self.plugin_id,
66            payload: &self.payload,
67        }
68        .serialize(serializer)
69    }
70}
71
72impl<'de> serde::Deserialize<'de> for ProtocolEvent {
73    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
74    where
75        D: serde::Deserializer<'de>,
76    {
77        let value = serde_json::Value::deserialize(deserializer)?;
78        if let Some(object) = value.as_object()
79            && let (Some(plugin_id), Some(payload)) =
80                (object.get("plugin_id"), object.get("payload"))
81        {
82            return Ok(Self {
83                plugin_id: plugin_id
84                    .as_str()
85                    .ok_or_else(|| serde::de::Error::custom("plugin_id must be a string"))?
86                    .to_string(),
87                payload: payload.clone(),
88            });
89        }
90        Err(serde::de::Error::custom(
91            "protocol events must be tagged with plugin_id and payload",
92        ))
93    }
94}
95
/// The `lash_sansio` session event record instantiated with [`ProtocolEvent`]
/// as its type parameter.
pub type SessionEventRecord = lash_sansio::session_model::SessionEventRecord<ProtocolEvent>;

/// Plugin id under which plugin-runtime events are tagged when they are
/// wrapped as [`ProtocolEvent`]s (see [`plugin_runtime_protocol_event`]).
pub const PLUGIN_RUNTIME_PROTOCOL_PLUGIN_ID: &str = "lash.plugin_runtime";
99
/// Serializable pairing of a plugin runtime event with a plugin id.
///
/// This is the payload type written by [`plugin_runtime_protocol_event`] and
/// read back by [`plugin_runtime_event_from_protocol`].
#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
pub struct PersistedPluginRuntimeEvent {
    /// Plugin id stored alongside the event.
    pub plugin_id: String,
    /// The plugin runtime event itself.
    pub event: crate::PluginRuntimeEvent,
}
105
106pub fn plugin_runtime_protocol_event(
107    plugin_id: impl Into<String>,
108    event: crate::PluginRuntimeEvent,
109) -> Result<ProtocolEvent, serde_json::Error> {
110    ProtocolEvent::typed(
111        PLUGIN_RUNTIME_PROTOCOL_PLUGIN_ID,
112        PersistedPluginRuntimeEvent {
113            plugin_id: plugin_id.into(),
114            event,
115        },
116    )
117}
118
119pub fn plugin_runtime_event_from_protocol(
120    event: &ProtocolEvent,
121) -> Result<Option<PersistedPluginRuntimeEvent>, serde_json::Error> {
122    event.decode(PLUGIN_RUNTIME_PROTOCOL_PLUGIN_ID)
123}
124
125/// Send an event to the channel if it's still open.
126pub(crate) async fn send_event(tx: &mpsc::Sender<SessionEvent>, event: SessionEvent) {
127    if !tx.is_closed() {
128        let _ = tx.send(event).await;
129    }
130}
131
132pub(crate) fn plugin_message_to_message(plugin_message: &PluginMessage) -> Message {
133    let message_id = fresh_message_id();
134    let mut parts = if plugin_message.parts.is_empty() {
135        vec![Part {
136            id: format!("{message_id}.p0"),
137            kind: PartKind::Text,
138            content: plugin_message.content.clone(),
139            attachment: None,
140            tool_call_id: None,
141            tool_name: None,
142            tool_replay: None,
143            prune_state: PruneState::Intact,
144            reasoning_meta: None,
145            response_meta: None,
146        }]
147    } else {
148        plugin_message.parts.clone()
149    };
150    reassign_part_ids(&message_id, &mut parts);
151    Message {
152        id: message_id,
153        role: plugin_message.role,
154        parts: Arc::new(parts),
155        origin: plugin_message.origin.clone().or_else(|| {
156            Some(crate::MessageOrigin::Plugin {
157                plugin_id: "plugin".to_string(),
158                transient: false,
159            })
160        }),
161    }
162}
163
/// Session policy with hand-written serde implementations.
///
/// Only the provider *id* is part of the serialized form; input that carries
/// a `provider` key is rejected on deserialization.
#[derive(Clone, Debug, Default, PartialEq, Eq)]
pub struct SessionPolicy {
    /// Model specification for the session.
    pub model: ModelSpec,
    /// Provider identifier; serialized with surrounding whitespace trimmed.
    pub provider_id: String,
    /// Optional session identifier.
    pub session_id: Option<String>,
    pub autonomous: bool,
    /// Optional turn limit; `None` means no limit is recorded.
    pub max_turns: Option<usize>,
    /// Prompt layer; omitted from the serialized form when empty.
    pub prompt: crate::PromptLayer,
}

impl SessionPolicy {
    /// Returns `provider_id` with leading and trailing whitespace removed.
    pub fn recorded_provider_id(&self) -> &str {
        self.provider_id.trim()
    }

    /// Returns the id of the configured model.
    pub fn model_id(&self) -> &str {
        &self.model.id
    }

    /// Returns the configured model variant, if any.
    pub fn model_variant(&self) -> Option<&str> {
        self.model.variant.as_deref()
    }

    /// Returns the context window size in tokens, as reported by the model spec.
    pub fn context_window_tokens(&self) -> usize {
        self.model.context_window_tokens()
    }
}
191
192impl serde::Serialize for SessionPolicy {
193    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
194    where
195        S: serde::Serializer,
196    {
197        use serde::ser::SerializeStruct;
198
199        let mut fields = 5;
200        if !self.prompt.is_empty() {
201            fields += 1;
202        }
203        let mut state = serializer.serialize_struct("SessionPolicy", fields)?;
204        state.serialize_field("model", &self.model)?;
205        state.serialize_field("provider_id", self.recorded_provider_id())?;
206        state.serialize_field("session_id", &self.session_id)?;
207        state.serialize_field("autonomous", &self.autonomous)?;
208        state.serialize_field("max_turns", &self.max_turns)?;
209        if !self.prompt.is_empty() {
210            state.serialize_field("prompt", &self.prompt)?;
211        }
212        state.end()
213    }
214}
215
216impl<'de> serde::Deserialize<'de> for SessionPolicy {
217    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
218    where
219        D: serde::Deserializer<'de>,
220    {
221        #[derive(serde::Deserialize)]
222        #[serde(deny_unknown_fields)]
223        struct Wire {
224            #[serde(default)]
225            model: ModelSpec,
226            #[serde(default)]
227            provider_id: String,
228            #[serde(default)]
229            session_id: Option<String>,
230            #[serde(default)]
231            autonomous: bool,
232            #[serde(default)]
233            max_turns: Option<usize>,
234            #[serde(default)]
235            prompt: crate::PromptLayer,
236        }
237
238        let value = serde_json::Value::deserialize(deserializer)?;
239        if value
240            .as_object()
241            .is_some_and(|object| object.contains_key("provider"))
242        {
243            return Err(serde::de::Error::custom(
244                "legacy serialized provider config is not supported in session state; persist provider_id only",
245            ));
246        }
247        let wire = Wire::deserialize(value).map_err(serde::de::Error::custom)?;
248        Ok(Self {
249            model: wire.model,
250            provider_id: wire.provider_id,
251            session_id: wire.session_id,
252            autonomous: wire.autonomous,
253            max_turns: wire.max_turns,
254            prompt: wire.prompt,
255        })
256    }
257}
258
259/// Runtime-only policy resolved against host-owned live dependencies.
260#[derive(Clone, Debug, Default, PartialEq, Eq)]
261pub struct RuntimeSessionPolicy {
262    pub policy: SessionPolicy,
263    pub binding: ProviderBinding,
264}
265
266impl RuntimeSessionPolicy {
267    pub fn new(policy: SessionPolicy, binding: ProviderBinding) -> Self {
268        Self { policy, binding }
269    }
270
271    pub fn from_provider(
272        policy: SessionPolicy,
273        provider: ProviderHandle,
274    ) -> Result<Self, ProviderResolutionError> {
275        let binding = ProviderBinding::new(policy.recorded_provider_id(), provider)?;
276        Ok(Self { policy, binding })
277    }
278
279    pub fn provider(&self) -> &ProviderHandle {
280        &self.binding.provider
281    }
282
283    pub fn into_policy(self) -> SessionPolicy {
284        self.policy
285    }
286}
287
/// Gives read access to the wrapped [`SessionPolicy`]'s fields and methods
/// directly on a [`RuntimeSessionPolicy`].
impl std::ops::Deref for RuntimeSessionPolicy {
    type Target = SessionPolicy;

    fn deref(&self) -> &Self::Target {
        &self.policy
    }
}
295
/// Gives mutable access to the wrapped [`SessionPolicy`]; the binding is not
/// touched by mutations made through this impl.
impl std::ops::DerefMut for RuntimeSessionPolicy {
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut self.policy
    }
}
301
302/// Reusable session configuration overlay.
303///
304/// `SessionSpec` is the public configuration shape for callers that want to
305/// describe either a root session or a child session without constructing the
306/// persisted [`SessionPolicy`] directly.
307#[derive(Clone, Debug, PartialEq, Eq)]
308pub struct SessionSpec {
309    inherit: bool,
310    pub provider_id: Option<String>,
311    pub model: Option<ModelSpec>,
312    pub max_turns: Option<Option<usize>>,
313    pub prompt: Option<crate::PromptLayer>,
314}
315
316impl SessionSpec {
317    /// Create an explicit root-style spec. Unset fields resolve from the
318    /// runtime's core defaults.
319    pub fn new() -> Self {
320        Self {
321            inherit: false,
322            provider_id: None,
323            model: None,
324            max_turns: None,
325            prompt: None,
326        }
327    }
328
329    /// Create a parent-relative spec. Unset fields inherit from the live
330    /// parent policy at resolution time.
331    pub fn inherit() -> Self {
332        Self {
333            inherit: true,
334            ..Self::new()
335        }
336    }
337
338    pub fn inherits(&self) -> bool {
339        self.inherit
340    }
341
342    pub fn provider_id(mut self, provider_id: impl Into<String>) -> Self {
343        self.provider_id = Some(provider_id.into());
344        self
345    }
346
347    pub fn model(mut self, model: ModelSpec) -> Self {
348        self.model = Some(model);
349        self
350    }
351
352    pub fn max_turns(mut self, max_turns: usize) -> Self {
353        self.max_turns = Some(Some(max_turns));
354        self
355    }
356
357    pub fn clear_max_turns(mut self) -> Self {
358        self.max_turns = Some(None);
359        self
360    }
361
362    pub fn prompt_layer(mut self, prompt: crate::PromptLayer) -> Self {
363        self.prompt = Some(prompt);
364        self
365    }
366
367    pub fn resolve_against(&self, base: &SessionPolicy) -> SessionPolicy {
368        let mut policy = base.clone();
369        if let Some(provider_id) = self.provider_id.as_ref() {
370            policy.provider_id = provider_id.clone();
371        }
372        if let Some(model) = self.model.as_ref() {
373            policy.model = model.clone();
374        }
375        if let Some(max_turns) = self.max_turns {
376            policy.max_turns = max_turns;
377        }
378        if let Some(prompt) = self.prompt.as_ref() {
379            policy.prompt = prompt.clone();
380        }
381        policy
382    }
383}
384
/// The default spec is the root-style spec returned by [`SessionSpec::new`]
/// (non-inheriting, no overrides).
impl Default for SessionSpec {
    fn default() -> Self {
        Self::new()
    }
}
390
391pub(crate) fn transport_stream_events(
392    provider: &ProviderHandle,
393    requested: Option<tokio::sync::mpsc::UnboundedSender<LlmStreamEvent>>,
394) -> Option<LlmEventSender> {
395    if let Some(requested) = requested {
396        return Some(make_stream_event_sender(requested));
397    }
398
399    if provider.requires_streaming() {
400        let (tx, rx) = tokio::sync::mpsc::unbounded_channel::<LlmStreamEvent>();
401        drop(rx);
402        Some(make_stream_event_sender(tx))
403    } else {
404        None
405    }
406}
407
408fn make_stream_event_sender(
409    tx: tokio::sync::mpsc::UnboundedSender<LlmStreamEvent>,
410) -> LlmEventSender {
411    LlmEventSender::new(move |event| {
412        let _ = tx.send(event);
413    })
414}
415
#[cfg(test)]
mod tests {
    use super::*;

    /// A typed protocol event serializes to an object exposing `plugin_id`
    /// and `payload` keys.
    #[test]
    fn protocol_event_writes_tagged_payload() {
        let event = ProtocolEvent::typed("test_protocol", serde_json::json!({ "value": 42 }))
            .expect("typed event");
        let serialized = serde_json::to_value(event).expect("serialize");
        assert_eq!(serialized["plugin_id"], "test_protocol");
        assert!(serialized.get("payload").is_some());
    }

    /// Input carrying a `provider` object must be rejected with the dedicated
    /// legacy-config error message.
    #[test]
    fn session_policy_rejects_legacy_provider_config() {
        let err = serde_json::from_value::<SessionPolicy>(serde_json::json!({
            "model": {},
            "provider": {
                "type": "openai",
                "api_key": "must-not-load"
            }
        }))
        .expect_err("legacy provider config must fail");

        assert!(
            err.to_string()
                .contains("legacy serialized provider config is not supported")
        );
    }

    /// Serialized policies contain `provider_id` and no `provider` key.
    #[test]
    fn session_policy_serializes_provider_id_without_provider_handle() {
        let policy = SessionPolicy {
            provider_id: "mock-provider".to_string(),
            model: ModelSpec::from_token_limits("mock-model", None, 200_000, None)
                .expect("valid test model"),
            ..SessionPolicy::default()
        };

        let value = serde_json::to_value(&policy).expect("serialize policy");

        assert_eq!(value["provider_id"], "mock-provider");
        assert!(value.get("provider").is_none());
    }
}
460}