telltale_language/extensions/
timeout.rs1use super::{
7 CodegenContext, ExtensionRegistry, ExtensionValidationError, GrammarExtension, ParseContext,
8 ParseError, ProjectionContext, ProtocolExtension, StatementParser,
9};
10use crate::ast::{LocalType, Role};
11use crate::compiler::projection::ProjectionError;
12use std::any::{Any, TypeId};
13use std::time::Duration;
14
15#[derive(Debug)]
17pub struct TimeoutGrammarExtension;
18
19impl GrammarExtension for TimeoutGrammarExtension {
20 fn grammar_rules(&self) -> &'static str {
21 r#"
22timeout_stmt = { "timeout" ~ timeout_duration ~ timeout_roles ~ "{" ~ protocol_body ~ "}" }
23timeout_duration = { integer ~ time_unit? }
24time_unit = { "ms" | "s" | "m" | "h" }
25timeout_roles = { "(" ~ role_list ~ ")" | role_ref }
26"#
27 }
28
29 fn statement_rules(&self) -> Vec<&'static str> {
30 vec!["timeout_stmt"]
31 }
32
33 fn priority(&self) -> u32 {
34 200 }
36
37 fn extension_id(&self) -> &'static str {
38 "timeout"
39 }
40}
41
42#[derive(Debug)]
44pub struct TimeoutStatementParser;
45
46impl StatementParser for TimeoutStatementParser {
47 fn can_parse(&self, rule_name: &str) -> bool {
48 rule_name == "timeout_stmt"
49 }
50
51 fn supported_rules(&self) -> Vec<String> {
52 vec!["timeout_stmt".to_string()]
53 }
54
55 fn parse_statement(
56 &self,
57 rule_name: &str,
58 _content: &str,
59 context: &ParseContext,
60 ) -> Result<Box<dyn ProtocolExtension>, ParseError> {
61 if rule_name != "timeout_stmt" {
62 return Err(ParseError::InvalidSyntax {
63 details: format!("Expected timeout_stmt, got {}", rule_name),
64 });
65 }
66
67 let timeout_protocol = self.parse_timeout_content(_content, context)?;
70 Ok(Box::new(timeout_protocol))
71 }
72}
73
74impl TimeoutStatementParser {
75 fn parse_timeout_content(
76 &self,
77 content: &str,
78 context: &ParseContext,
79 ) -> Result<TimeoutProtocol, ParseError> {
80 let duration_ms = self.extract_duration(content)?;
82 let roles = self.extract_roles(content, context)?;
83
84 Ok(TimeoutProtocol {
86 duration: Duration::from_millis(duration_ms),
87 role_names: roles.iter().map(|r| r.name().to_string()).collect(),
88 body_repr: "End".to_string(),
89 })
90 }
91
92 fn extract_duration(&self, content: &str) -> Result<u64, ParseError> {
93 let duration_str = content
96 .split_whitespace()
97 .find(|s| s.chars().all(|c| c.is_ascii_digit()))
98 .ok_or_else(|| ParseError::InvalidSyntax {
99 details: "Could not find timeout duration".to_string(),
100 })?;
101
102 duration_str.parse().map_err(|_| ParseError::InvalidSyntax {
103 details: "Invalid timeout duration format".to_string(),
104 })
105 }
106
107 fn extract_roles(
108 &self,
109 _content: &str,
110 context: &ParseContext,
111 ) -> Result<Vec<Role>, ParseError> {
112 Ok(context.declared_roles.to_vec())
115 }
116}
117
/// Parsed `timeout` statement: a duration applied to a set of roles.
#[derive(Debug, Clone)]
pub struct TimeoutProtocol {
    /// How long the guarded body may run.
    pub duration: Duration,
    /// Names of the roles the timeout applies to.
    pub role_names: Vec<String>,
    /// Textual representation of the guarded body.
    pub body_repr: String,
}
127
128impl ProtocolExtension for TimeoutProtocol {
129 fn type_name(&self) -> &'static str {
130 "TimeoutProtocol"
131 }
132
133 fn mentions_role(&self, role: &Role) -> bool {
134 self.role_names
135 .iter()
136 .any(|name| name == &role.name().to_string())
137 }
138
139 fn validate(&self, all_roles: &[Role]) -> Result<(), ExtensionValidationError> {
140 for role_name in &self.role_names {
142 if !all_roles.iter().any(|r| &r.name().to_string() == role_name) {
143 return Err(ExtensionValidationError::UndeclaredRole {
144 role: role_name.clone(),
145 });
146 }
147 }
148
149 if self.duration.is_zero() {
151 return Err(ExtensionValidationError::InvalidStructure {
152 reason: "Timeout duration cannot be zero".to_string(),
153 });
154 }
155
156 if self.duration > Duration::from_secs(3600) {
157 return Err(ExtensionValidationError::InvalidStructure {
158 reason: "Timeout duration too long (max 1 hour)".to_string(),
159 });
160 }
161
162 Ok(())
163 }
164
165 fn project(
166 &self,
167 role: &Role,
168 _context: &ProjectionContext,
169 ) -> Result<LocalType, ProjectionError> {
170 if self
171 .role_names
172 .iter()
173 .any(|name| name == &role.name().to_string())
174 {
175 Ok(LocalType::Timeout {
178 duration: self.duration,
179 body: Box::new(LocalType::End),
180 })
181 } else {
182 Ok(LocalType::End)
184 }
185 }
186
187 fn generate_code(&self, _context: &CodegenContext) -> proc_macro2::TokenStream {
188 let duration_ms = u64::try_from(self.duration.as_millis()).unwrap_or(u64::MAX);
189 let _role_names = &self.role_names;
190
191 quote::quote! {
192 .with_timeout(
194 Duration::from_millis(#duration_ms),
195 )
197 }
198 }
199
200 fn as_any(&self) -> &dyn Any {
201 self
202 }
203
204 fn as_any_mut(&mut self) -> &mut dyn Any {
205 self
206 }
207
208 fn type_id(&self) -> TypeId {
209 TypeId::of::<Self>()
210 }
211}
212
213pub fn register_timeout_extension(
219 registry: &mut ExtensionRegistry,
220) -> Result<(), crate::extensions::ParseError> {
221 registry.register_grammar(TimeoutGrammarExtension)?;
222 registry.register_parser(TimeoutStatementParser, "timeout".to_string());
223 Ok(())
224}
225
226impl LocalType {
228 pub fn timeout(duration: Duration, body: LocalType) -> Self {
229 Self::Timeout {
230 duration,
231 body: Box::new(body),
232 }
233 }
234}
235
#[cfg(test)]
mod tests {
    use super::*;
    use proc_macro2::{Ident, Span};

    /// Builds a role with the given identifier.
    fn role(name: &str) -> Role {
        Role::new(Ident::new(name, Span::call_site())).expect("valid role identifier")
    }

    /// Builds a timeout over `roles` with an `End` body.
    fn timeout_over(duration: Duration, roles: &[Role]) -> TimeoutProtocol {
        TimeoutProtocol {
            duration,
            role_names: roles.iter().map(|r| r.name().to_string()).collect(),
            body_repr: "End".to_string(),
        }
    }

    #[test]
    fn test_timeout_grammar_extension() {
        let ext = TimeoutGrammarExtension;
        assert_eq!(ext.extension_id(), "timeout");
        assert!(ext.statement_rules().contains(&"timeout_stmt"));
        assert!(ext.grammar_rules().contains("timeout_stmt"));
    }

    #[test]
    fn test_timeout_statement_parser() {
        let parser = TimeoutStatementParser;
        assert!(parser.can_parse("timeout_stmt"));
        assert!(!parser.can_parse("unknown_stmt"));
    }

    #[test]
    fn test_timeout_protocol() {
        let timeout_protocol = timeout_over(Duration::from_millis(5000), &[role("Alice")]);

        assert_eq!(timeout_protocol.type_name(), "TimeoutProtocol");
        assert!(timeout_protocol.mentions_role(&role("Alice")));
        assert!(!timeout_protocol.mentions_role(&role("Bob")));
    }

    #[test]
    fn test_timeout_validation() {
        let roles = vec![role("Alice")];

        assert!(timeout_over(Duration::from_millis(5000), &roles)
            .validate(&roles)
            .is_ok());
        assert!(timeout_over(Duration::ZERO, &roles)
            .validate(&roles)
            .is_err());
    }

    #[test]
    fn test_timeout_validation_rejects_undeclared_role() {
        let declared = vec![role("Alice")];
        let timeout = timeout_over(Duration::from_millis(5000), &[role("Mallory")]);

        assert!(matches!(
            timeout.validate(&declared),
            Err(ExtensionValidationError::UndeclaredRole { .. })
        ));
    }

    #[test]
    fn test_timeout_validation_enforces_one_hour_limit() {
        let roles = vec![role("Alice")];

        // Exactly one hour is the inclusive upper bound.
        assert!(timeout_over(Duration::from_secs(3600), &roles)
            .validate(&roles)
            .is_ok());
        assert!(matches!(
            timeout_over(Duration::from_secs(3601), &roles).validate(&roles),
            Err(ExtensionValidationError::InvalidStructure { .. })
        ));
    }

    #[test]
    fn test_local_type_timeout_constructor() {
        let built = LocalType::timeout(Duration::from_secs(2), LocalType::End);
        assert!(matches!(
            built,
            LocalType::Timeout { duration, .. } if duration == Duration::from_secs(2)
        ));
    }

    #[test]
    fn test_extension_registration() {
        let mut registry = ExtensionRegistry::new();
        register_timeout_extension(&mut registry).expect("extension should register");

        assert!(registry.can_handle("timeout_stmt"));
    }
}