1use crate::support::*;
2
3#[derive(
4 Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash, serde::Serialize, serde::Deserialize,
5)]
6pub struct ModeId(String);
7
8impl ModeId {
9 pub fn new(mode: impl Into<String>) -> Self {
10 Self(mode.into())
11 }
12
13 pub fn standard() -> Self {
14 Self("standard".to_string())
15 }
16
17 pub fn rlm() -> Self {
18 Self("rlm".to_string())
19 }
20
21 pub fn as_str(&self) -> &str {
22 &self.0
23 }
24}
25
26impl fmt::Display for ModeId {
27 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
28 f.write_str(self.as_str())
29 }
30}
31
32#[derive(Clone)]
34pub struct ModePreset {
35 pub(crate) mode_id: ModeId,
36 pub(crate) factory: Arc<dyn PluginFactory>,
37}
38
39impl ModePreset {
40 pub fn standard() -> Self {
41 Self {
42 mode_id: ModeId::standard(),
43 factory: Arc::new(lash_protocol_standard::StandardProtocolPluginFactory::new()),
44 }
45 }
46
47 pub fn rlm() -> Self {
48 Self {
49 mode_id: ModeId::rlm(),
50 factory: Arc::new(lash_protocol_rlm::RlmProtocolPluginFactory::new(
51 rlm_preset_config(lash_protocol_rlm::RlmProtocolPluginConfig::default()),
52 )),
53 }
54 }
55
56 pub fn rlm_with_config(config: lash_protocol_rlm::RlmProtocolPluginConfig) -> Self {
57 Self {
58 mode_id: ModeId::rlm(),
59 factory: Arc::new(lash_protocol_rlm::RlmProtocolPluginFactory::new(
60 rlm_preset_config(config),
61 )),
62 }
63 }
64
65 pub fn rlm_with_projection_resolver(
66 projection_resolver: Arc<dyn lash_protocol_rlm::ProjectionResolver>,
67 ) -> Self {
68 Self::rlm_with_config_and_projection_resolver(
69 lash_protocol_rlm::RlmProtocolPluginConfig::default(),
70 projection_resolver,
71 )
72 }
73
74 pub fn rlm_with_config_and_projection_resolver(
75 config: lash_protocol_rlm::RlmProtocolPluginConfig,
76 projection_resolver: Arc<dyn lash_protocol_rlm::ProjectionResolver>,
77 ) -> Self {
78 Self {
79 mode_id: ModeId::rlm(),
80 factory: Arc::new(
81 lash_protocol_rlm::RlmProtocolPluginFactory::new(rlm_preset_config(config))
82 .with_projection_resolver(projection_resolver),
83 ),
84 }
85 }
86
87 pub fn mode_id(&self) -> &ModeId {
88 &self.mode_id
89 }
90}
91
92pub trait RlmTurnBuilderExt: Sized {
93 fn require_submit(self) -> Result<Self>;
94 fn require_submit_schema(self, schema: serde_json::Value) -> Result<Self>;
95 fn allow_prose_or_submit(self) -> Result<Self>;
96}
97
98impl RlmTurnBuilderExt for TurnBuilder {
99 fn require_submit(self) -> Result<Self> {
100 rlm_termination(
101 self,
102 lash_rlm_types::RlmTermination::SubmitRequired { schema: None },
103 )
104 }
105
106 fn require_submit_schema(self, schema: serde_json::Value) -> Result<Self> {
107 rlm_termination(
108 self,
109 lash_rlm_types::RlmTermination::SubmitRequired {
110 schema: Some(schema),
111 },
112 )
113 }
114
115 fn allow_prose_or_submit(self) -> Result<Self> {
116 rlm_termination(self, lash_rlm_types::RlmTermination::ProseOrSubmit)
117 }
118}
119
120pub trait RlmSessionBuilderExt: Sized {
121 fn final_answer_format(self, format: lash_rlm_types::RlmFinalAnswerFormat) -> Self;
122}
123
124impl RlmSessionBuilderExt for SessionBuilder {
125 fn final_answer_format(mut self, format: lash_rlm_types::RlmFinalAnswerFormat) -> Self {
126 self.rlm_final_answer_format = Some(format);
127 self
128 }
129}
130
131fn rlm_termination(
132 mut builder: TurnBuilder,
133 termination: lash_rlm_types::RlmTermination,
134) -> Result<TurnBuilder> {
135 let override_options = ProtocolTurnOptions::typed(lash_rlm_types::RlmCreateExtras {
136 termination,
137 final_answer_format: None,
138 })?;
139 let options = builder
140 .protocol_turn_options
141 .as_ref()
142 .map(|current| current.merged_with_override(&override_options))
143 .unwrap_or(override_options);
144 builder.protocol_turn_options = Some(options);
145 Ok(builder)
146}
147
148fn rlm_preset_config(
149 config: lash_protocol_rlm::RlmProtocolPluginConfig,
150) -> lash_protocol_rlm::RlmProtocolPluginConfig {
151 let language_features = config.lashlang_language_features.with_label_annotations();
152 config.with_lashlang_language_features(language_features)
153}