1pub mod context;
2pub use lash_sansio::session_model::message;
3pub use lash_sansio::session_model::prompt;
4
5use std::sync::Arc;
6use tokio::sync::mpsc;
7
8use crate::ModelSpec;
9use crate::llm::types::{LlmEventSender, LlmStreamEvent};
10use crate::plugin::PluginMessage;
11use crate::provider::{ProviderBinding, ProviderHandle, ProviderResolutionError};
12
13pub use lash_sansio::format_tool_output_content;
14pub use lash_sansio::session_model::{
15 ConversationRecord, ErrorEnvelope, MAIN_AGENT_INTRO, Message, MessageRole, Part, PartKind,
16 PromptBuiltin, PromptSlot, PromptTemplate, PromptTemplateEntry, PromptTemplateSection,
17 PruneState, SessionEvent, TokenUsage, TurnTerminationPolicyState, default_prompt_template,
18 make_error_envelope, make_error_event, reassign_part_ids, render_prompt,
19 render_transcript_prompt, shared_parts,
20};
21
22pub fn fresh_message_id() -> String {
23 format!("m{}", uuid::Uuid::new_v4().simple())
24}
25
26#[derive(Clone, Debug, PartialEq)]
27pub struct ProtocolEvent {
28 pub plugin_id: String,
29 pub payload: serde_json::Value,
30}
31
32impl ProtocolEvent {
33 pub fn typed<T>(plugin_id: impl Into<String>, event: T) -> Result<Self, serde_json::Error>
34 where
35 T: serde::Serialize,
36 {
37 Ok(Self {
38 plugin_id: plugin_id.into(),
39 payload: serde_json::to_value(event)?,
40 })
41 }
42
43 pub fn decode<T>(&self, expected_plugin_id: &str) -> Result<Option<T>, serde_json::Error>
44 where
45 T: for<'de> serde::Deserialize<'de>,
46 {
47 if self.plugin_id != expected_plugin_id {
48 return Ok(None);
49 }
50 serde_json::from_value(self.payload.clone()).map(Some)
51 }
52}
53
54impl serde::Serialize for ProtocolEvent {
55 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
56 where
57 S: serde::Serializer,
58 {
59 #[derive(serde::Serialize)]
60 struct Tagged<'a> {
61 plugin_id: &'a str,
62 payload: &'a serde_json::Value,
63 }
64 Tagged {
65 plugin_id: &self.plugin_id,
66 payload: &self.payload,
67 }
68 .serialize(serializer)
69 }
70}
71
72impl<'de> serde::Deserialize<'de> for ProtocolEvent {
73 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
74 where
75 D: serde::Deserializer<'de>,
76 {
77 let value = serde_json::Value::deserialize(deserializer)?;
78 if let Some(object) = value.as_object()
79 && let (Some(plugin_id), Some(payload)) =
80 (object.get("plugin_id"), object.get("payload"))
81 {
82 return Ok(Self {
83 plugin_id: plugin_id
84 .as_str()
85 .ok_or_else(|| serde::de::Error::custom("plugin_id must be a string"))?
86 .to_string(),
87 payload: payload.clone(),
88 });
89 }
90 Err(serde::de::Error::custom(
91 "protocol events must be tagged with plugin_id and payload",
92 ))
93 }
94}
95
/// `lash_sansio`'s generic `SessionEventRecord` instantiated with this module's
/// [`ProtocolEvent`].
pub type SessionEventRecord = lash_sansio::session_model::SessionEventRecord<ProtocolEvent>;
97
98pub(crate) async fn send_event(tx: &mpsc::Sender<SessionEvent>, event: SessionEvent) {
100 if !tx.is_closed() {
101 let _ = tx.send(event).await;
102 }
103}
104
105pub(crate) fn plugin_message_to_message(plugin_message: &PluginMessage) -> Message {
106 let message_id = fresh_message_id();
107 let mut parts = if plugin_message.parts.is_empty() {
108 vec![Part {
109 id: format!("{message_id}.p0"),
110 kind: PartKind::Text,
111 content: plugin_message.content.clone(),
112 attachment: None,
113 tool_call_id: None,
114 tool_name: None,
115 tool_replay: None,
116 prune_state: PruneState::Intact,
117 reasoning_meta: None,
118 response_meta: None,
119 }]
120 } else {
121 plugin_message.parts.clone()
122 };
123 reassign_part_ids(&message_id, &mut parts);
124 Message {
125 id: message_id,
126 role: plugin_message.role,
127 parts: Arc::new(parts),
128 origin: plugin_message.origin.clone().or_else(|| {
129 Some(crate::MessageOrigin::Plugin {
130 plugin_id: "plugin".to_string(),
131 transient: false,
132 })
133 }),
134 }
135}
136
137#[derive(Clone, Debug, Default, PartialEq, Eq)]
138pub struct SessionPolicy {
139 pub model: ModelSpec,
140 pub provider_id: String,
141 pub session_id: Option<String>,
142 pub autonomous: bool,
143 pub max_turns: Option<usize>,
144 pub prompt: crate::PromptLayer,
145}
146
147impl SessionPolicy {
148 pub fn recorded_provider_id(&self) -> &str {
149 self.provider_id.trim()
150 }
151
152 pub fn model_id(&self) -> &str {
153 &self.model.id
154 }
155
156 pub fn model_variant(&self) -> Option<&str> {
157 self.model.variant.as_deref()
158 }
159
160 pub fn context_window_tokens(&self) -> usize {
161 self.model.context_window_tokens()
162 }
163}
164
165impl serde::Serialize for SessionPolicy {
166 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
167 where
168 S: serde::Serializer,
169 {
170 use serde::ser::SerializeStruct;
171
172 let mut fields = 5;
173 if !self.prompt.is_empty() {
174 fields += 1;
175 }
176 let mut state = serializer.serialize_struct("SessionPolicy", fields)?;
177 state.serialize_field("model", &self.model)?;
178 state.serialize_field("provider_id", self.recorded_provider_id())?;
179 state.serialize_field("session_id", &self.session_id)?;
180 state.serialize_field("autonomous", &self.autonomous)?;
181 state.serialize_field("max_turns", &self.max_turns)?;
182 if !self.prompt.is_empty() {
183 state.serialize_field("prompt", &self.prompt)?;
184 }
185 state.end()
186 }
187}
188
189impl<'de> serde::Deserialize<'de> for SessionPolicy {
190 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
191 where
192 D: serde::Deserializer<'de>,
193 {
194 #[derive(serde::Deserialize)]
195 #[serde(deny_unknown_fields)]
196 struct Wire {
197 #[serde(default)]
198 model: ModelSpec,
199 #[serde(default)]
200 provider_id: String,
201 #[serde(default)]
202 session_id: Option<String>,
203 #[serde(default)]
204 autonomous: bool,
205 #[serde(default)]
206 max_turns: Option<usize>,
207 #[serde(default)]
208 prompt: crate::PromptLayer,
209 }
210
211 let value = serde_json::Value::deserialize(deserializer)?;
212 if value
213 .as_object()
214 .is_some_and(|object| object.contains_key("provider"))
215 {
216 return Err(serde::de::Error::custom(
217 "legacy serialized provider config is not supported in session state; persist provider_id only",
218 ));
219 }
220 let wire = Wire::deserialize(value).map_err(serde::de::Error::custom)?;
221 Ok(Self {
222 model: wire.model,
223 provider_id: wire.provider_id,
224 session_id: wire.session_id,
225 autonomous: wire.autonomous,
226 max_turns: wire.max_turns,
227 prompt: wire.prompt,
228 })
229 }
230}
231
232#[derive(Clone, Debug, Default, PartialEq, Eq)]
234pub struct RuntimeSessionPolicy {
235 pub policy: SessionPolicy,
236 pub binding: ProviderBinding,
237}
238
239impl RuntimeSessionPolicy {
240 pub fn new(policy: SessionPolicy, binding: ProviderBinding) -> Self {
241 Self { policy, binding }
242 }
243
244 pub fn from_provider(
245 policy: SessionPolicy,
246 provider: ProviderHandle,
247 ) -> Result<Self, ProviderResolutionError> {
248 let binding = ProviderBinding::new(policy.recorded_provider_id(), provider)?;
249 Ok(Self { policy, binding })
250 }
251
252 pub fn provider(&self) -> &ProviderHandle {
253 &self.binding.provider
254 }
255
256 pub fn into_policy(self) -> SessionPolicy {
257 self.policy
258 }
259}
260
261impl std::ops::Deref for RuntimeSessionPolicy {
262 type Target = SessionPolicy;
263
264 fn deref(&self) -> &Self::Target {
265 &self.policy
266 }
267}
268
269impl std::ops::DerefMut for RuntimeSessionPolicy {
270 fn deref_mut(&mut self) -> &mut Self::Target {
271 &mut self.policy
272 }
273}
274
275#[derive(Clone, Debug, PartialEq, Eq)]
281pub struct SessionSpec {
282 inherit: bool,
283 pub provider_id: Option<String>,
284 pub model: Option<ModelSpec>,
285 pub max_turns: Option<Option<usize>>,
286 pub prompt: Option<crate::PromptLayer>,
287}
288
289impl SessionSpec {
290 pub fn new() -> Self {
293 Self {
294 inherit: false,
295 provider_id: None,
296 model: None,
297 max_turns: None,
298 prompt: None,
299 }
300 }
301
302 pub fn inherit() -> Self {
305 Self {
306 inherit: true,
307 ..Self::new()
308 }
309 }
310
311 pub fn inherits(&self) -> bool {
312 self.inherit
313 }
314
315 pub fn provider_id(mut self, provider_id: impl Into<String>) -> Self {
316 self.provider_id = Some(provider_id.into());
317 self
318 }
319
320 pub fn model(mut self, model: ModelSpec) -> Self {
321 self.model = Some(model);
322 self
323 }
324
325 pub fn max_turns(mut self, max_turns: usize) -> Self {
326 self.max_turns = Some(Some(max_turns));
327 self
328 }
329
330 pub fn clear_max_turns(mut self) -> Self {
331 self.max_turns = Some(None);
332 self
333 }
334
335 pub fn prompt_layer(mut self, prompt: crate::PromptLayer) -> Self {
336 self.prompt = Some(prompt);
337 self
338 }
339
340 pub fn resolve_against(&self, base: &SessionPolicy) -> SessionPolicy {
341 let mut policy = base.clone();
342 if let Some(provider_id) = self.provider_id.as_ref() {
343 policy.provider_id = provider_id.clone();
344 }
345 if let Some(model) = self.model.as_ref() {
346 policy.model = model.clone();
347 }
348 if let Some(max_turns) = self.max_turns {
349 policy.max_turns = max_turns;
350 }
351 if let Some(prompt) = self.prompt.as_ref() {
352 policy.prompt = prompt.clone();
353 }
354 policy
355 }
356}
357
358impl Default for SessionSpec {
359 fn default() -> Self {
360 Self::new()
361 }
362}
363
364pub(crate) fn transport_stream_events(
365 provider: &ProviderHandle,
366 requested: Option<tokio::sync::mpsc::UnboundedSender<LlmStreamEvent>>,
367) -> Option<LlmEventSender> {
368 if let Some(requested) = requested {
369 return Some(make_stream_event_sender(requested));
370 }
371
372 if provider.requires_streaming() {
373 let (tx, rx) = tokio::sync::mpsc::unbounded_channel::<LlmStreamEvent>();
374 drop(rx);
375 Some(make_stream_event_sender(tx))
376 } else {
377 None
378 }
379}
380
381fn make_stream_event_sender(
382 tx: tokio::sync::mpsc::UnboundedSender<LlmStreamEvent>,
383) -> LlmEventSender {
384 LlmEventSender::new(move |event| {
385 let _ = tx.send(event);
386 })
387}
388
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn protocol_event_writes_tagged_payload() {
        let event = ProtocolEvent::typed("test_protocol", serde_json::json!({ "value": 42 }))
            .expect("typed event");
        let serialized = serde_json::to_value(event).expect("serialize");
        assert_eq!(serialized["plugin_id"], "test_protocol");
        assert!(serialized.get("payload").is_some());
    }

    #[test]
    fn protocol_event_round_trips_through_json() {
        let event = ProtocolEvent::typed("test_protocol", serde_json::json!({ "value": 42 }))
            .expect("typed event");
        let serialized = serde_json::to_value(&event).expect("serialize");
        let restored: ProtocolEvent = serde_json::from_value(serialized).expect("deserialize");
        assert_eq!(restored, event);
    }

    #[test]
    fn protocol_event_decode_filters_by_plugin_id() {
        let event = ProtocolEvent::typed("test_protocol", serde_json::json!({ "value": 42 }))
            .expect("typed event");

        let foreign = event
            .decode::<serde_json::Value>("other_protocol")
            .expect("decode foreign");
        assert!(foreign.is_none());

        let own = event
            .decode::<serde_json::Value>("test_protocol")
            .expect("decode own");
        assert_eq!(own, Some(serde_json::json!({ "value": 42 })));
    }

    #[test]
    fn protocol_event_rejects_untagged_payload() {
        let err = serde_json::from_value::<ProtocolEvent>(serde_json::json!({ "value": 42 }))
            .expect_err("untagged event must fail");
        assert!(
            err.to_string()
                .contains("protocol events must be tagged with plugin_id and payload")
        );
    }

    #[test]
    fn protocol_event_rejects_non_string_plugin_id() {
        let err = serde_json::from_value::<ProtocolEvent>(serde_json::json!({
            "plugin_id": 7,
            "payload": {}
        }))
        .expect_err("numeric plugin_id must fail");
        assert!(err.to_string().contains("plugin_id must be a string"));
    }

    #[test]
    fn session_policy_rejects_legacy_provider_config() {
        let err = serde_json::from_value::<SessionPolicy>(serde_json::json!({
            "model": {},
            "provider": {
                "type": "openai",
                "api_key": "must-not-load"
            }
        }))
        .expect_err("legacy provider config must fail");

        assert!(
            err.to_string()
                .contains("legacy serialized provider config is not supported")
        );
    }

    #[test]
    fn session_policy_serializes_provider_id_without_provider_handle() {
        let policy = SessionPolicy {
            provider_id: "mock-provider".to_string(),
            model: ModelSpec::from_token_limits("mock-model", None, 200_000, None)
                .expect("valid test model"),
            ..SessionPolicy::default()
        };

        let value = serde_json::to_value(&policy).expect("serialize policy");

        assert_eq!(value["provider_id"], "mock-provider");
        assert!(value.get("provider").is_none());
    }

    #[test]
    fn session_policy_serializes_trimmed_provider_id() {
        let policy = SessionPolicy {
            provider_id: "  mock-provider  ".to_string(),
            ..SessionPolicy::default()
        };

        let value = serde_json::to_value(&policy).expect("serialize policy");

        assert_eq!(value["provider_id"], "mock-provider");
    }

    #[test]
    fn session_spec_applies_only_requested_overrides() {
        let base = SessionPolicy {
            provider_id: "base-provider".to_string(),
            max_turns: Some(3),
            ..SessionPolicy::default()
        };

        assert_eq!(SessionSpec::new().resolve_against(&base), base);

        let cleared = SessionSpec::new().clear_max_turns().resolve_against(&base);
        assert_eq!(cleared.max_turns, None);
        assert_eq!(cleared.provider_id, "base-provider");

        let spec = SessionSpec::inherit().provider_id("other").max_turns(7);
        assert!(spec.inherits());
        assert!(!SessionSpec::new().inherits());
        let overridden = spec.resolve_against(&base);
        assert_eq!(overridden.provider_id, "other");
        assert_eq!(overridden.max_turns, Some(7));
        assert_eq!(overridden.model, base.model);
    }
}
433}