1pub mod context;
2pub use lash_sansio::session_model::message;
3pub use lash_sansio::session_model::prompt;
4
5use std::sync::Arc;
6use tokio::sync::mpsc;
7
8use crate::ModelSpec;
9use crate::llm::types::{LlmEventSender, LlmStreamEvent};
10use crate::plugin::PluginMessage;
11use crate::provider::{ProviderBinding, ProviderHandle, ProviderResolutionError};
12
13pub use lash_sansio::format_tool_output_content;
14pub use lash_sansio::session_model::{
15 ConversationRecord, ErrorEnvelope, MAIN_AGENT_INTRO, Message, MessageRole, Part, PartKind,
16 PromptBuiltin, PromptSlot, PromptTemplate, PromptTemplateEntry, PromptTemplateSection,
17 PruneState, SessionEvent, TokenUsage, TurnTerminationPolicyState, default_prompt_template,
18 make_error_envelope, make_error_event, reassign_part_ids, render_prompt,
19 render_transcript_prompt, shared_parts,
20};
21
22pub fn fresh_message_id() -> String {
23 format!("m{}", uuid::Uuid::new_v4().simple())
24}
25
/// An opaque JSON `payload` tagged with the `plugin_id` whose schema it follows.
///
/// Serialized as an object with exactly the keys `plugin_id` and `payload`
/// (see the hand-written `Serialize`/`Deserialize` impls for this type).
#[derive(Clone, Debug, PartialEq)]
pub struct ProtocolEvent {
    /// Identifier of the plugin/protocol that owns the payload format; checked by
    /// `ProtocolEvent::decode` before the payload is interpreted.
    pub plugin_id: String,
    /// Event body as untyped JSON.
    pub payload: serde_json::Value,
}
31
32impl ProtocolEvent {
33 pub fn typed<T>(plugin_id: impl Into<String>, event: T) -> Result<Self, serde_json::Error>
34 where
35 T: serde::Serialize,
36 {
37 Ok(Self {
38 plugin_id: plugin_id.into(),
39 payload: serde_json::to_value(event)?,
40 })
41 }
42
43 pub fn decode<T>(&self, expected_plugin_id: &str) -> Result<Option<T>, serde_json::Error>
44 where
45 T: for<'de> serde::Deserialize<'de>,
46 {
47 if self.plugin_id != expected_plugin_id {
48 return Ok(None);
49 }
50 serde_json::from_value(self.payload.clone()).map(Some)
51 }
52}
53
54impl serde::Serialize for ProtocolEvent {
55 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
56 where
57 S: serde::Serializer,
58 {
59 #[derive(serde::Serialize)]
60 struct Tagged<'a> {
61 plugin_id: &'a str,
62 payload: &'a serde_json::Value,
63 }
64 Tagged {
65 plugin_id: &self.plugin_id,
66 payload: &self.payload,
67 }
68 .serialize(serializer)
69 }
70}
71
72impl<'de> serde::Deserialize<'de> for ProtocolEvent {
73 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
74 where
75 D: serde::Deserializer<'de>,
76 {
77 let value = serde_json::Value::deserialize(deserializer)?;
78 if let Some(object) = value.as_object()
79 && let (Some(plugin_id), Some(payload)) =
80 (object.get("plugin_id"), object.get("payload"))
81 {
82 return Ok(Self {
83 plugin_id: plugin_id
84 .as_str()
85 .ok_or_else(|| serde::de::Error::custom("plugin_id must be a string"))?
86 .to_string(),
87 payload: payload.clone(),
88 });
89 }
90 Err(serde::de::Error::custom(
91 "protocol events must be tagged with plugin_id and payload",
92 ))
93 }
94}
95
/// The sansio session event record, specialised to carry [`ProtocolEvent`]s.
pub type SessionEventRecord = lash_sansio::session_model::SessionEventRecord<ProtocolEvent>;

/// Plugin id used to tag [`ProtocolEvent`]s that wrap a [`PersistedPluginRuntimeEvent`]
/// (produced by `plugin_runtime_protocol_event`, read by `plugin_runtime_event_from_protocol`).
pub const PLUGIN_RUNTIME_PROTOCOL_PLUGIN_ID: &str = "lash.plugin_runtime";
99
/// Payload stored in a [`ProtocolEvent`] tagged with [`PLUGIN_RUNTIME_PROTOCOL_PLUGIN_ID`]:
/// a plugin runtime event plus the plugin id it was recorded for.
#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
pub struct PersistedPluginRuntimeEvent {
    /// Plugin id supplied by the caller of `plugin_runtime_protocol_event`; this is
    /// distinct from the protocol-level tag `PLUGIN_RUNTIME_PROTOCOL_PLUGIN_ID`.
    pub plugin_id: String,
    /// The wrapped runtime event.
    pub event: crate::PluginRuntimeEvent,
}
105
106pub fn plugin_runtime_protocol_event(
107 plugin_id: impl Into<String>,
108 event: crate::PluginRuntimeEvent,
109) -> Result<ProtocolEvent, serde_json::Error> {
110 ProtocolEvent::typed(
111 PLUGIN_RUNTIME_PROTOCOL_PLUGIN_ID,
112 PersistedPluginRuntimeEvent {
113 plugin_id: plugin_id.into(),
114 event,
115 },
116 )
117}
118
119pub fn plugin_runtime_event_from_protocol(
120 event: &ProtocolEvent,
121) -> Result<Option<PersistedPluginRuntimeEvent>, serde_json::Error> {
122 event.decode(PLUGIN_RUNTIME_PROTOCOL_PLUGIN_ID)
123}
124
125pub(crate) async fn send_event(tx: &mpsc::Sender<SessionEvent>, event: SessionEvent) {
127 if !tx.is_closed() {
128 let _ = tx.send(event).await;
129 }
130}
131
132pub(crate) fn plugin_message_to_message(plugin_message: &PluginMessage) -> Message {
133 let message_id = fresh_message_id();
134 let mut parts = if plugin_message.parts.is_empty() {
135 vec![Part {
136 id: format!("{message_id}.p0"),
137 kind: PartKind::Text,
138 content: plugin_message.content.clone(),
139 attachment: None,
140 tool_call_id: None,
141 tool_name: None,
142 tool_replay: None,
143 prune_state: PruneState::Intact,
144 reasoning_meta: None,
145 response_meta: None,
146 }]
147 } else {
148 plugin_message.parts.clone()
149 };
150 reassign_part_ids(&message_id, &mut parts);
151 Message {
152 id: message_id,
153 role: plugin_message.role,
154 parts: Arc::new(parts),
155 origin: plugin_message.origin.clone().or_else(|| {
156 Some(crate::MessageOrigin::Plugin {
157 plugin_id: "plugin".to_string(),
158 transient: false,
159 })
160 }),
161 }
162}
163
/// Persistable per-session settings: model, provider id, session id, and turn/prompt policy.
///
/// Serde impls are hand-written. On output, `provider_id` is trimmed and `prompt` is
/// omitted when empty. On input, every field is optional, unknown fields are rejected,
/// and a legacy `provider` key is refused with a dedicated error.
#[derive(Clone, Debug, Default, PartialEq, Eq)]
pub struct SessionPolicy {
    /// Model specification for the session.
    pub model: ModelSpec,
    /// Provider id; the trimmed form (`recorded_provider_id`) is what gets serialized.
    pub provider_id: String,
    /// Session identifier, if one is set.
    pub session_id: Option<String>,
    /// Autonomous-mode flag.
    pub autonomous: bool,
    /// Optional turn limit; `None` leaves it unset.
    pub max_turns: Option<usize>,
    /// Prompt layer; skipped during serialization when `is_empty()` is true.
    pub prompt: crate::PromptLayer,
}
173
174impl SessionPolicy {
175 pub fn recorded_provider_id(&self) -> &str {
176 self.provider_id.trim()
177 }
178
179 pub fn model_id(&self) -> &str {
180 &self.model.id
181 }
182
183 pub fn model_variant(&self) -> Option<&str> {
184 self.model.variant.as_deref()
185 }
186
187 pub fn context_window_tokens(&self) -> usize {
188 self.model.context_window_tokens()
189 }
190}
191
192impl serde::Serialize for SessionPolicy {
193 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
194 where
195 S: serde::Serializer,
196 {
197 use serde::ser::SerializeStruct;
198
199 let mut fields = 5;
200 if !self.prompt.is_empty() {
201 fields += 1;
202 }
203 let mut state = serializer.serialize_struct("SessionPolicy", fields)?;
204 state.serialize_field("model", &self.model)?;
205 state.serialize_field("provider_id", self.recorded_provider_id())?;
206 state.serialize_field("session_id", &self.session_id)?;
207 state.serialize_field("autonomous", &self.autonomous)?;
208 state.serialize_field("max_turns", &self.max_turns)?;
209 if !self.prompt.is_empty() {
210 state.serialize_field("prompt", &self.prompt)?;
211 }
212 state.end()
213 }
214}
215
216impl<'de> serde::Deserialize<'de> for SessionPolicy {
217 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
218 where
219 D: serde::Deserializer<'de>,
220 {
221 #[derive(serde::Deserialize)]
222 #[serde(deny_unknown_fields)]
223 struct Wire {
224 #[serde(default)]
225 model: ModelSpec,
226 #[serde(default)]
227 provider_id: String,
228 #[serde(default)]
229 session_id: Option<String>,
230 #[serde(default)]
231 autonomous: bool,
232 #[serde(default)]
233 max_turns: Option<usize>,
234 #[serde(default)]
235 prompt: crate::PromptLayer,
236 }
237
238 let value = serde_json::Value::deserialize(deserializer)?;
239 if value
240 .as_object()
241 .is_some_and(|object| object.contains_key("provider"))
242 {
243 return Err(serde::de::Error::custom(
244 "legacy serialized provider config is not supported in session state; persist provider_id only",
245 ));
246 }
247 let wire = Wire::deserialize(value).map_err(serde::de::Error::custom)?;
248 Ok(Self {
249 model: wire.model,
250 provider_id: wire.provider_id,
251 session_id: wire.session_id,
252 autonomous: wire.autonomous,
253 max_turns: wire.max_turns,
254 prompt: wire.prompt,
255 })
256 }
257}
258
/// A [`SessionPolicy`] paired with the [`ProviderBinding`] it is resolved against.
///
/// This type has no serde impls of its own. It dereferences (mutably, too) to the
/// inner `SessionPolicy`.
#[derive(Clone, Debug, Default, PartialEq, Eq)]
pub struct RuntimeSessionPolicy {
    /// The persistable policy.
    pub policy: SessionPolicy,
    /// Provider binding created for `policy` (see `RuntimeSessionPolicy::from_provider`).
    pub binding: ProviderBinding,
}
265
266impl RuntimeSessionPolicy {
267 pub fn new(policy: SessionPolicy, binding: ProviderBinding) -> Self {
268 Self { policy, binding }
269 }
270
271 pub fn from_provider(
272 policy: SessionPolicy,
273 provider: ProviderHandle,
274 ) -> Result<Self, ProviderResolutionError> {
275 let binding = ProviderBinding::new(policy.recorded_provider_id(), provider)?;
276 Ok(Self { policy, binding })
277 }
278
279 pub fn provider(&self) -> &ProviderHandle {
280 &self.binding.provider
281 }
282
283 pub fn into_policy(self) -> SessionPolicy {
284 self.policy
285 }
286}
287
288impl std::ops::Deref for RuntimeSessionPolicy {
289 type Target = SessionPolicy;
290
291 fn deref(&self) -> &Self::Target {
292 &self.policy
293 }
294}
295
296impl std::ops::DerefMut for RuntimeSessionPolicy {
297 fn deref_mut(&mut self) -> &mut Self::Target {
298 &mut self.policy
299 }
300}
301
/// Optional overrides applied on top of a base [`SessionPolicy`].
///
/// In `SessionSpec::resolve_against`, every `Some` field replaces the base value and
/// every `None` keeps it. `max_turns` is doubly optional so that an override can
/// explicitly clear the limit (`Some(None)`, via `clear_max_turns`).
///
/// The private `inherit` flag is set by `SessionSpec::inherit` and reported by
/// `SessionSpec::inherits`. `resolve_against` does not read it.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct SessionSpec {
    inherit: bool,
    /// Replacement provider id.
    pub provider_id: Option<String>,
    /// Replacement model spec.
    pub model: Option<ModelSpec>,
    /// `Some(Some(n))` sets the limit, `Some(None)` clears it, `None` keeps the base value.
    pub max_turns: Option<Option<usize>>,
    /// Replacement prompt layer.
    pub prompt: Option<crate::PromptLayer>,
}
315
316impl SessionSpec {
317 pub fn new() -> Self {
320 Self {
321 inherit: false,
322 provider_id: None,
323 model: None,
324 max_turns: None,
325 prompt: None,
326 }
327 }
328
329 pub fn inherit() -> Self {
332 Self {
333 inherit: true,
334 ..Self::new()
335 }
336 }
337
338 pub fn inherits(&self) -> bool {
339 self.inherit
340 }
341
342 pub fn provider_id(mut self, provider_id: impl Into<String>) -> Self {
343 self.provider_id = Some(provider_id.into());
344 self
345 }
346
347 pub fn model(mut self, model: ModelSpec) -> Self {
348 self.model = Some(model);
349 self
350 }
351
352 pub fn max_turns(mut self, max_turns: usize) -> Self {
353 self.max_turns = Some(Some(max_turns));
354 self
355 }
356
357 pub fn clear_max_turns(mut self) -> Self {
358 self.max_turns = Some(None);
359 self
360 }
361
362 pub fn prompt_layer(mut self, prompt: crate::PromptLayer) -> Self {
363 self.prompt = Some(prompt);
364 self
365 }
366
367 pub fn resolve_against(&self, base: &SessionPolicy) -> SessionPolicy {
368 let mut policy = base.clone();
369 if let Some(provider_id) = self.provider_id.as_ref() {
370 policy.provider_id = provider_id.clone();
371 }
372 if let Some(model) = self.model.as_ref() {
373 policy.model = model.clone();
374 }
375 if let Some(max_turns) = self.max_turns {
376 policy.max_turns = max_turns;
377 }
378 if let Some(prompt) = self.prompt.as_ref() {
379 policy.prompt = prompt.clone();
380 }
381 policy
382 }
383}
384
385impl Default for SessionSpec {
386 fn default() -> Self {
387 Self::new()
388 }
389}
390
391pub(crate) fn transport_stream_events(
392 provider: &ProviderHandle,
393 requested: Option<tokio::sync::mpsc::UnboundedSender<LlmStreamEvent>>,
394) -> Option<LlmEventSender> {
395 if let Some(requested) = requested {
396 return Some(make_stream_event_sender(requested));
397 }
398
399 if provider.requires_streaming() {
400 let (tx, rx) = tokio::sync::mpsc::unbounded_channel::<LlmStreamEvent>();
401 drop(rx);
402 Some(make_stream_event_sender(tx))
403 } else {
404 None
405 }
406}
407
408fn make_stream_event_sender(
409 tx: tokio::sync::mpsc::UnboundedSender<LlmStreamEvent>,
410) -> LlmEventSender {
411 LlmEventSender::new(move |event| {
412 let _ = tx.send(event);
413 })
414}
415
#[cfg(test)]
mod tests {
    use super::*;

    /// Model spec shared by the policy tests.
    fn test_model() -> ModelSpec {
        ModelSpec::from_token_limits("mock-model", None, 200_000, None).expect("valid test model")
    }

    #[test]
    fn protocol_event_writes_tagged_payload() {
        let event = ProtocolEvent::typed("test_protocol", serde_json::json!({ "value": 42 }))
            .expect("typed event");
        let serialized = serde_json::to_value(event).expect("serialize");
        assert_eq!(serialized["plugin_id"], "test_protocol");
        assert!(serialized.get("payload").is_some());
    }

    #[test]
    fn protocol_event_round_trips_through_json() {
        let event = ProtocolEvent::typed("test_protocol", serde_json::json!({ "value": 42 }))
            .expect("typed event");
        let serialized = serde_json::to_value(&event).expect("serialize");
        let restored: ProtocolEvent = serde_json::from_value(serialized).expect("deserialize");
        assert_eq!(restored, event);
    }

    #[test]
    fn protocol_event_decode_filters_by_plugin_id() {
        let event = ProtocolEvent::typed("test_protocol", serde_json::json!({ "value": 42 }))
            .expect("typed event");
        let foreign: Option<serde_json::Value> =
            event.decode("other_protocol").expect("foreign decode");
        assert!(foreign.is_none());
        let own: Option<serde_json::Value> = event.decode("test_protocol").expect("own decode");
        assert_eq!(own, Some(serde_json::json!({ "value": 42 })));
    }

    #[test]
    fn protocol_event_rejects_untagged_or_malformed_values() {
        assert!(
            serde_json::from_value::<ProtocolEvent>(serde_json::json!({ "plugin_id": "x" }))
                .is_err()
        );
        assert!(serde_json::from_value::<ProtocolEvent>(serde_json::json!(["x", null])).is_err());

        let err = serde_json::from_value::<ProtocolEvent>(serde_json::json!({
            "plugin_id": 7,
            "payload": null
        }))
        .expect_err("non-string plugin_id must fail");
        assert!(err.to_string().contains("plugin_id must be a string"));
    }

    #[test]
    fn session_policy_rejects_legacy_provider_config() {
        let err = serde_json::from_value::<SessionPolicy>(serde_json::json!({
            "model": {},
            "provider": {
                "type": "openai",
                "api_key": "must-not-load"
            }
        }))
        .expect_err("legacy provider config must fail");

        assert!(
            err.to_string()
                .contains("legacy serialized provider config is not supported")
        );
    }

    #[test]
    fn session_policy_serializes_provider_id_without_provider_handle() {
        let policy = SessionPolicy {
            provider_id: "mock-provider".to_string(),
            model: test_model(),
            ..SessionPolicy::default()
        };

        let value = serde_json::to_value(&policy).expect("serialize policy");

        assert_eq!(value["provider_id"], "mock-provider");
        assert!(value.get("provider").is_none());
    }

    #[test]
    fn session_policy_serializes_trimmed_provider_id() {
        let policy = SessionPolicy {
            provider_id: "  mock-provider\n".to_string(),
            model: test_model(),
            ..SessionPolicy::default()
        };

        let value = serde_json::to_value(&policy).expect("serialize policy");

        assert_eq!(value["provider_id"], "mock-provider");
    }

    #[test]
    fn session_policy_round_trips_through_json() {
        let policy = SessionPolicy {
            provider_id: "mock-provider".to_string(),
            model: test_model(),
            session_id: Some("session-1".to_string()),
            autonomous: true,
            max_turns: Some(7),
            ..SessionPolicy::default()
        };

        let value = serde_json::to_value(&policy).expect("serialize policy");
        let restored: SessionPolicy = serde_json::from_value(value).expect("deserialize policy");

        assert_eq!(restored, policy);
    }

    #[test]
    fn session_policy_defaults_missing_fields_and_rejects_unknown_ones() {
        let empty = serde_json::from_value::<SessionPolicy>(serde_json::json!({}))
            .expect("empty object is a valid policy");
        assert_eq!(empty, SessionPolicy::default());

        assert!(
            serde_json::from_value::<SessionPolicy>(serde_json::json!({ "bogus": true })).is_err()
        );
    }

    #[test]
    fn session_spec_overrides_only_requested_fields() {
        let base = SessionPolicy {
            provider_id: "base-provider".to_string(),
            autonomous: true,
            max_turns: Some(3),
            ..SessionPolicy::default()
        };
        assert_eq!(SessionSpec::new().resolve_against(&base), base);
        assert!(!SessionSpec::default().inherits());

        let spec = SessionSpec::inherit()
            .provider_id("child-provider")
            .model(test_model())
            .clear_max_turns();
        assert!(spec.inherits());

        let resolved = spec.resolve_against(&base);
        assert_eq!(resolved.provider_id, "child-provider");
        assert_eq!(resolved.model, test_model());
        assert_eq!(resolved.max_turns, None);
        assert!(resolved.autonomous);
        assert_eq!(resolved.session_id, base.session_id);

        let capped = SessionSpec::new().max_turns(9).resolve_against(&base);
        assert_eq!(capped.max_turns, Some(9));
    }
}
460}