1use crate::error::{IronError, Result};
7use std::collections::{HashMap, HashSet};
8use std::time::{Duration, SystemTime};
9
10#[cfg(feature = "serde")]
11use serde::{Deserialize, Serialize};
12
/// An IRCv3 capability known to this client.
///
/// Names are converted with [`Capability::from_str`] and [`Capability::as_str`]; names
/// without a dedicated variant round-trip through [`Capability::Custom`].
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum Capability {
    // Capabilities advertised without a `draft/` prefix.
    MessageTags,
    ServerTime,
    AccountNotify,
    AccountTag,
    AwayNotify,
    Batch,
    CapNotify,
    ChgHost,
    EchoMessage,
    ExtendedJoin,
    InviteNotify,
    LabeledResponse,
    Monitor,
    MultiPrefix,
    Sasl,
    Setname,
    StandardReplies,
    UserhostInNames,
    BotMode,
    UTF8Only,
    StrictTransportSecurity,
    WebIRC,
    Chathistory,

    // Newer capabilities; all `draft/`-prefixed except `account-extban`.
    MessageRedaction,
    AccountExtban,
    Metadata2,
    MessageTagsUnlimited,
    Multiline,
    NoImplicitNames,
    PreAway,
    ReadMarker,
    RelayMsg,
    ReplyDrafts,
    TypingClient,
    WebSocket,
    ChannelRename,
    Persistence,
    ServerNameIndication,

    // Client-only tag capabilities (names start with `+`).
    ClientTyping,
    ClientReply,
    ClientReact,

    // Project protocol markers.
    LegionProtocolV1,
    #[deprecated(note = "Use LegionProtocolV1 instead")]
    IronProtocolV1,

    /// Any capability without a dedicated variant, stored verbatim.
    Custom(String),
}

impl Capability {
    /// Maps a capability name (without any `=value` suffix) to its variant.
    ///
    /// Unrecognised names are preserved verbatim as [`Capability::Custom`], so
    /// `Capability::from_str(s).as_str() == s` holds for every input.
    #[allow(deprecated)] // `+iron-protocol/v1` must still be recognised for compatibility.
    pub fn from_str(s: &str) -> Self {
        match s {
            "message-tags" => Capability::MessageTags,
            "server-time" => Capability::ServerTime,
            "account-notify" => Capability::AccountNotify,
            "account-tag" => Capability::AccountTag,
            "away-notify" => Capability::AwayNotify,
            "batch" => Capability::Batch,
            "cap-notify" => Capability::CapNotify,
            "chghost" => Capability::ChgHost,
            "echo-message" => Capability::EchoMessage,
            "extended-join" => Capability::ExtendedJoin,
            "invite-notify" => Capability::InviteNotify,
            "labeled-response" => Capability::LabeledResponse,
            "monitor" => Capability::Monitor,
            "multi-prefix" => Capability::MultiPrefix,
            "sasl" => Capability::Sasl,
            "setname" => Capability::Setname,
            "standard-replies" => Capability::StandardReplies,
            "userhost-in-names" => Capability::UserhostInNames,
            "bot" => Capability::BotMode,
            "utf8only" => Capability::UTF8Only,
            "sts" => Capability::StrictTransportSecurity,
            "webirc" => Capability::WebIRC,
            "chathistory" => Capability::Chathistory,

            "draft/message-redaction" => Capability::MessageRedaction,
            "account-extban" => Capability::AccountExtban,
            "draft/metadata-2" => Capability::Metadata2,

            "draft/message-tags-unlimited" => Capability::MessageTagsUnlimited,
            "draft/multiline" => Capability::Multiline,
            "draft/no-implicit-names" => Capability::NoImplicitNames,
            "draft/pre-away" => Capability::PreAway,
            "draft/read-marker" => Capability::ReadMarker,
            "draft/relaymsg" => Capability::RelayMsg,
            "draft/reply" => Capability::ReplyDrafts,
            "draft/typing" => Capability::TypingClient,
            "draft/websocket" => Capability::WebSocket,
            "draft/channel-rename" => Capability::ChannelRename,
            "draft/persistence" => Capability::Persistence,
            "draft/sni" => Capability::ServerNameIndication,

            "+typing" => Capability::ClientTyping,
            "+draft/reply" => Capability::ClientReply,
            "+draft/react" => Capability::ClientReact,

            "+legion-protocol/v1" => Capability::LegionProtocolV1,
            "+iron-protocol/v1" => Capability::IronProtocolV1,

            other => Capability::Custom(other.to_string()),
        }
    }

    /// Returns the wire name of this capability, as used in `CAP` messages.
    #[allow(deprecated)] // The deprecated variant still needs a wire name.
    pub fn as_str(&self) -> &str {
        match self {
            Capability::MessageTags => "message-tags",
            Capability::ServerTime => "server-time",
            Capability::AccountNotify => "account-notify",
            Capability::AccountTag => "account-tag",
            Capability::AwayNotify => "away-notify",
            Capability::Batch => "batch",
            Capability::CapNotify => "cap-notify",
            Capability::ChgHost => "chghost",
            Capability::EchoMessage => "echo-message",
            Capability::ExtendedJoin => "extended-join",
            Capability::InviteNotify => "invite-notify",
            Capability::LabeledResponse => "labeled-response",
            Capability::Monitor => "monitor",
            Capability::MultiPrefix => "multi-prefix",
            Capability::Sasl => "sasl",
            Capability::Setname => "setname",
            Capability::StandardReplies => "standard-replies",
            Capability::UserhostInNames => "userhost-in-names",
            Capability::BotMode => "bot",
            Capability::UTF8Only => "utf8only",
            Capability::StrictTransportSecurity => "sts",
            Capability::WebIRC => "webirc",
            Capability::Chathistory => "chathistory",

            Capability::MessageRedaction => "draft/message-redaction",
            Capability::AccountExtban => "account-extban",
            Capability::Metadata2 => "draft/metadata-2",

            Capability::MessageTagsUnlimited => "draft/message-tags-unlimited",
            Capability::Multiline => "draft/multiline",
            Capability::NoImplicitNames => "draft/no-implicit-names",
            Capability::PreAway => "draft/pre-away",
            Capability::ReadMarker => "draft/read-marker",
            Capability::RelayMsg => "draft/relaymsg",
            Capability::ReplyDrafts => "draft/reply",
            Capability::TypingClient => "draft/typing",
            Capability::WebSocket => "draft/websocket",
            Capability::ChannelRename => "draft/channel-rename",
            Capability::Persistence => "draft/persistence",
            Capability::ServerNameIndication => "draft/sni",

            Capability::ClientTyping => "+typing",
            Capability::ClientReply => "+draft/reply",
            Capability::ClientReact => "+draft/react",

            Capability::LegionProtocolV1 => "+legion-protocol/v1",
            Capability::IronProtocolV1 => "+iron-protocol/v1",

            Capability::Custom(s) => s,
        }
    }

    /// Returns `true` for the capabilities this crate treats as security-relevant:
    /// `sasl`, `sts`, `account-tag`, `account-notify`, and the Legion/Iron protocol markers.
    #[allow(deprecated)] // The legacy Iron marker is still security-relevant.
    pub fn is_security_critical(&self) -> bool {
        matches!(
            self,
            Capability::Sasl
                | Capability::StrictTransportSecurity
                | Capability::AccountTag
                | Capability::AccountNotify
                | Capability::LegionProtocolV1
                | Capability::IronProtocolV1
        )
    }

    /// Returns `true` if this capability's name lives under the `draft/` namespace.
    ///
    /// Client-only tag capabilities carry a leading `+` (e.g. `+draft/react`), so that
    /// prefix is ignored before checking; this also covers `Custom` names.
    pub fn is_draft(&self) -> bool {
        let name = self.as_str();
        name.strip_prefix('+').unwrap_or(name).starts_with("draft/")
    }
}
234
/// One capability entry as advertised by the server in `CAP LS` / `CAP NEW`
/// (for example `sasl=PLAIN,EXTERNAL`).
#[derive(Debug, Clone)]
pub struct CapabilitySpec {
    /// Capability name: the part of the advertised token before the first `=`.
    pub name: String,
    /// Advertised value: the part after the first `=`, or `None` if there was no `=`.
    pub value: Option<String>,
    /// `false` for entries that were merely advertised; `true` for the copies stored
    /// as enabled after the server acknowledged the capability.
    pub enabled: bool,
}
242
243pub struct CapabilityHandler {
245 version: u16,
246 available_caps: HashMap<String, CapabilitySpec>,
247 requested_caps: Vec<String>,
248 enabled_caps: HashMap<String, CapabilitySpec>,
249 negotiation_complete: bool,
250 sts_policies: HashMap<String, StsPolicy>,
251}
252
253#[derive(Debug, Clone)]
255pub struct StsPolicy {
256 pub duration: Duration,
257 pub port: Option<u16>,
258 pub preload: bool,
259 pub expires_at: SystemTime,
260}
261
262impl CapabilityHandler {
263 pub fn new() -> Self {
265 Self {
266 version: 302,
267 available_caps: HashMap::new(),
268 requested_caps: Vec::new(),
269 enabled_caps: HashMap::new(),
270 negotiation_complete: false,
271 sts_policies: HashMap::new(),
272 }
273 }
274
275 pub fn set_version(&mut self, version: u16) {
277 self.version = version;
278 }
279
280 pub fn handle_cap_ls(&mut self, params: &[String]) -> Result<bool> {
282 if params.len() < 2 {
283 return Err(IronError::Parse("Invalid CAP LS response".to_string()));
284 }
285
286 let is_multiline = params.len() > 2 && params[1] == "*";
287 let caps_list = if is_multiline { ¶ms[2] } else { ¶ms[1] };
288
289 self.parse_capabilities(caps_list)?;
290
291 Ok(!is_multiline)
292 }
293
294 pub fn handle_cap_ack(&mut self, caps: &[String]) -> Result<()> {
296 for cap_param in caps {
297 for cap_name in cap_param.split_whitespace() {
299 let cap_name = cap_name.trim();
300 if !cap_name.is_empty() {
301 if let Some(cap) = self.available_caps.get(cap_name) {
302 let mut enabled_cap = cap.clone();
303 enabled_cap.enabled = true;
304 self.enabled_caps.insert(cap_name.to_string(), enabled_cap);
305 }
306 }
307 }
308 }
309 Ok(())
310 }
311
312 pub fn handle_cap_nak(&mut self, caps: &[String]) -> Result<()> {
314 for cap in caps {
315 if self.get_essential_capabilities().contains(&cap.as_str()) {
316 if matches!(cap.as_str(), "sasl" | "sts") {
317 return Err(IronError::SecurityViolation(
318 format!("Essential security capability rejected: {}", cap)
319 ));
320 }
321 }
322
323 self.requested_caps.retain(|c| c != cap);
324 }
325 Ok(())
326 }
327
328 pub fn handle_cap_new(&mut self, caps_str: &str) -> Result<Vec<String>> {
330 if self.version < 302 {
331 return Ok(Vec::new());
332 }
333
334 self.parse_capabilities(caps_str)?;
335
336 let mut new_requests = Vec::new();
337 for cap_name in caps_str.split_whitespace() {
338 let cap_name = cap_name.split('=').next().unwrap_or(cap_name);
339 if self.get_essential_capabilities().contains(&cap_name) {
340 new_requests.push(cap_name.to_string());
341 }
342 }
343
344 Ok(new_requests)
345 }
346
347 pub fn handle_cap_del(&mut self, caps: &[String]) -> Result<()> {
349 for cap in caps {
350 self.available_caps.remove(cap);
351 self.enabled_caps.remove(cap);
352 }
353 Ok(())
354 }
355
356 pub fn get_capabilities_to_request(&self) -> Vec<String> {
358 let mut caps_to_request = Vec::new();
359
360 for &cap_name in &self.get_essential_capabilities() {
361 if self.available_caps.contains_key(cap_name) {
362 caps_to_request.push(cap_name.to_string());
363 }
364 }
365
366 if let Some(sasl_cap) = self.available_caps.get("sasl") {
368 if let Err(_) = self.validate_sasl_mechanisms(sasl_cap) {
369 caps_to_request.retain(|c| c != "sasl");
370 }
371 }
372
373 caps_to_request
374 }
375
376 pub fn is_capability_enabled(&self, cap_name: &str) -> bool {
378 self.enabled_caps.contains_key(cap_name)
379 }
380
381 pub fn get_sasl_mechanisms(&self) -> Vec<String> {
383 if let Some(sasl_cap) = self.enabled_caps.get("sasl") {
384 if let Some(value) = &sasl_cap.value {
385 return value.split(',').map(|s| s.trim().to_string()).collect();
386 }
387 }
388 Vec::new()
389 }
390
391 pub fn set_negotiation_complete(&mut self) {
393 self.negotiation_complete = true;
394 }
395
396 pub fn is_negotiation_complete(&self) -> bool {
398 self.negotiation_complete
399 }
400
401 pub fn handle_sts_policy(&mut self, hostname: &str, cap_value: &str) -> Result<()> {
403 let mut duration = None;
404 let mut port = None;
405 let mut preload = false;
406
407 for param in cap_value.split(',') {
408 let parts: Vec<&str> = param.splitn(2, '=').collect();
409 match parts[0].trim() {
410 "duration" => {
411 if parts.len() > 1 {
412 duration = Some(Duration::from_secs(
413 parts[1].parse().map_err(|_| {
414 IronError::Parse("Invalid STS duration".to_string())
415 })?
416 ));
417 }
418 }
419 "port" => {
420 if parts.len() > 1 {
421 port = Some(parts[1].parse().map_err(|_| {
422 IronError::Parse("Invalid STS port".to_string())
423 })?);
424 }
425 }
426 "preload" => preload = true,
427 _ => {}
428 }
429 }
430
431 let duration = duration.ok_or_else(|| {
432 IronError::Parse("STS policy missing duration".to_string())
433 })?;
434
435 if duration.as_secs() == 0 {
436 self.sts_policies.remove(hostname);
437 return Ok(());
438 }
439
440 let policy = StsPolicy {
441 duration,
442 port,
443 preload,
444 expires_at: SystemTime::now() + duration,
445 };
446
447 self.sts_policies.insert(hostname.to_string(), policy);
448 Ok(())
449 }
450
451 pub fn should_upgrade_to_tls(&self, hostname: &str) -> Option<u16> {
453 if let Some(policy) = self.sts_policies.get(hostname) {
454 if SystemTime::now() < policy.expires_at {
455 return policy.port.or(Some(6697));
456 }
457 }
458 None
459 }
460
461 fn parse_capabilities(&mut self, caps_str: &str) -> Result<()> {
463 for cap_spec in caps_str.split_whitespace() {
464 if cap_spec.is_empty() {
465 continue;
466 }
467
468 let (name, value) = if let Some(eq_pos) = cap_spec.find('=') {
469 (&cap_spec[..eq_pos], Some(&cap_spec[eq_pos + 1..]))
470 } else {
471 (cap_spec, None)
472 };
473
474 if !self.is_valid_capability_name(name) {
475 return Err(IronError::SecurityViolation(
476 format!("Invalid capability name: {}", name)
477 ));
478 }
479
480 self.available_caps.insert(name.to_string(), CapabilitySpec {
481 name: name.to_string(),
482 value: value.map(String::from),
483 enabled: false,
484 });
485 }
486 Ok(())
487 }
488
489 pub fn get_essential_capabilities(&self) -> Vec<&str> {
491 vec![
492 "sasl",
494 "message-tags",
495 "server-time",
496 "batch",
497 "+draft/react",
499 "+draft/reply",
500 ]
501 }
502
503 fn validate_sasl_mechanisms(&self, sasl_cap: &CapabilitySpec) -> Result<()> {
505 if let Some(value) = &sasl_cap.value {
506 let mechanisms: Vec<&str> = value.split(',').collect();
507
508 let preferred_order = ["SCRAM-SHA-256", "EXTERNAL", "PLAIN"];
509
510 for &preferred in &preferred_order {
511 if mechanisms.iter().any(|m| m.trim() == preferred) {
512 return Ok(());
513 }
514 }
515
516 return Err(IronError::Auth(
517 "No supported SASL mechanisms".to_string()
518 ));
519 }
520 Ok(())
521 }
522
523 fn is_valid_capability_name(&self, name: &str) -> bool {
525 if name.is_empty() || name.len() > 64 {
526 return false;
527 }
528
529 if name.starts_with('-') {
530 return false;
531 }
532
533 if name.contains('/') {
534 let parts: Vec<&str> = name.split('/').collect();
535 if parts.len() != 2 {
536 return false;
537 }
538
539 if parts[0].contains('.') && !parts[0].ends_with(".com")
540 && !parts[0].ends_with(".org") && !parts[0].ends_with(".net")
541 && !parts[0].ends_with(".chat") && !parts[0].ends_with(".in") {
542 return false;
543 }
544 }
545
546 name.chars().all(|c| {
547 c.is_ascii_alphanumeric() ||
548 c == '-' || c == '/' || c == '.' || c == '_' || c == '+'
549 })
550 }
551}
552
553impl Default for CapabilityHandler {
554 fn default() -> Self {
555 Self::new()
556 }
557}
558
559pub struct CapabilitySet {
561 capabilities: HashSet<Capability>,
562}
563
564impl CapabilitySet {
565 pub fn new() -> Self {
567 let mut capabilities = HashSet::new();
568
569 capabilities.insert(Capability::MessageTags);
571 capabilities.insert(Capability::ServerTime);
572 capabilities.insert(Capability::AccountNotify);
573 capabilities.insert(Capability::AccountTag);
574 capabilities.insert(Capability::AwayNotify);
575 capabilities.insert(Capability::Batch);
576 capabilities.insert(Capability::CapNotify);
577 capabilities.insert(Capability::ChgHost);
578 capabilities.insert(Capability::EchoMessage);
579 capabilities.insert(Capability::ExtendedJoin);
580 capabilities.insert(Capability::InviteNotify);
581 capabilities.insert(Capability::LabeledResponse);
582 capabilities.insert(Capability::Monitor);
583 capabilities.insert(Capability::MultiPrefix);
584 capabilities.insert(Capability::Sasl);
585 capabilities.insert(Capability::Setname);
586 capabilities.insert(Capability::StandardReplies);
587 capabilities.insert(Capability::UserhostInNames);
588 capabilities.insert(Capability::BotMode);
589 capabilities.insert(Capability::UTF8Only);
590 capabilities.insert(Capability::StrictTransportSecurity);
591 capabilities.insert(Capability::Chathistory);
592
593 Self { capabilities }
594 }
595
596 pub fn stable_only() -> Self {
598 Self::new() }
600
601 pub fn bleeding_edge() -> Self {
603 let mut set = Self::new();
604
605 set.add(Capability::MessageRedaction);
607 set.add(Capability::AccountExtban);
608 set.add(Capability::Metadata2);
609
610 set.add(Capability::MessageTagsUnlimited);
612 set.add(Capability::Multiline);
613 set.add(Capability::NoImplicitNames);
614 set.add(Capability::PreAway);
615 set.add(Capability::ReadMarker);
616 set.add(Capability::RelayMsg);
617 set.add(Capability::ReplyDrafts);
618 set.add(Capability::TypingClient);
619 set.add(Capability::WebSocket);
620 set.add(Capability::ChannelRename);
621 set.add(Capability::Persistence);
622 set.add(Capability::ServerNameIndication);
623
624 set.add(Capability::ClientTyping);
626 set.add(Capability::ClientReply);
627 set.add(Capability::ClientReact);
628
629 set
630 }
631
632 pub fn supports(&self, cap: &Capability) -> bool {
634 self.capabilities.contains(cap)
635 }
636
637 pub fn add(&mut self, cap: Capability) {
639 self.capabilities.insert(cap);
640 }
641
642 pub fn remove(&mut self, cap: &Capability) -> bool {
644 self.capabilities.remove(cap)
645 }
646
647 pub fn to_string_list(&self) -> Vec<String> {
649 self.capabilities
650 .iter()
651 .map(|cap| cap.as_str().to_string())
652 .collect()
653 }
654
655 pub fn to_cap_ls_string(&self) -> String {
657 self.to_string_list().join(" ")
658 }
659}
660
661impl Default for CapabilitySet {
662 fn default() -> Self {
663 Self::new()
664 }
665}
666
#[cfg(test)]
mod tests {
    use super::*;

    /// A known name maps to its variant and back to the same wire string.
    #[test]
    fn test_capability_parsing() {
        let parsed = Capability::from_str("message-tags");
        assert_eq!(parsed, Capability::MessageTags);
        assert_eq!(parsed.as_str(), "message-tags");
    }

    /// `draft/` names are drafts; ratified names are not.
    #[test]
    fn test_draft_capability_detection() {
        assert!(Capability::from_str("draft/multiline").is_draft());
        assert!(!Capability::from_str("message-tags").is_draft());
    }

    /// SASL is security-critical; message-tags is not.
    #[test]
    fn test_security_critical_detection() {
        assert!(Capability::from_str("sasl").is_security_critical());
        assert!(!Capability::from_str("message-tags").is_security_critical());
    }

    /// A single-line CAP LS completes the listing and records every advertised name.
    #[test]
    fn test_capability_handler() {
        let mut handler = CapabilityHandler::new();
        let ls_params: Vec<String> = ["testnick", "sasl=PLAIN message-tags"]
            .iter()
            .map(|param| param.to_string())
            .collect();

        let finished = handler.handle_cap_ls(&ls_params).unwrap();
        assert!(finished);

        for name in &["sasl", "message-tags"] {
            assert!(handler.available_caps.contains_key(*name));
        }
    }

    /// The bleeding-edge preset contains both stable and draft capabilities.
    #[test]
    fn test_capability_set() {
        let set = CapabilitySet::bleeding_edge();
        for cap in &[
            Capability::MessageTags,
            Capability::MessageRedaction,
            Capability::Multiline,
        ] {
            assert!(set.supports(cap));
        }
    }
}