use crate::{ChannelType, IronError, Result};
use crate::utils::get_channel_type;
use crate::capabilities::Capability;
/// Iron Protocol revisions understood by this implementation.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum IronVersion {
    /// First revision, advertised as `+iron-protocol/v1`.
    V1,
}
/// Legion Protocol revisions understood by this implementation.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum LegionVersion {
    /// First revision, advertised as `+legion-protocol/v1`.
    V1,
}

impl LegionVersion {
    /// Returns the capability string advertised for this Legion version.
    pub fn as_capability(&self) -> &'static str {
        match *self {
            Self::V1 => "+legion-protocol/v1",
        }
    }
}
impl IronVersion {
    /// Returns the capability string advertised for this Iron version.
    pub fn as_capability(&self) -> &'static str {
        match *self {
            Self::V1 => "+iron-protocol/v1",
        }
    }

    /// Maps this Iron version onto the Legion version it corresponds to.
    pub fn to_legion_version(&self) -> LegionVersion {
        match *self {
            Self::V1 => LegionVersion::V1,
        }
    }
}
/// Outcome of capability negotiation between a client and a server.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum IronNegotiationResult {
    /// Both sides advertise the Legion Protocol.
    LegionCapable { version: LegionVersion },
    /// Both sides advertise the Iron Protocol, but not both advertise Legion.
    IronCapable { version: IronVersion },
    /// At least one side advertises Iron or Legion, but no protocol is shared.
    IrcFallback,
    /// Neither side advertises Iron or Legion.
    NotSupported,
}
/// Tracks the negotiated Iron/Legion protocol versions and the channels
/// that have been marked as encrypted.
#[derive(Debug, Clone)]
pub struct IronSession {
    /// Iron Protocol version recorded via `set_version`, if any.
    iron_version: Option<IronVersion>,
    /// Legion Protocol version recorded via `set_legion_version`, if any.
    legion_version: Option<LegionVersion>,
    /// Channel names marked encrypted; kept free of duplicates by `add_encrypted_channel`.
    encrypted_channels: Vec<String>,
    /// Set by `complete_negotiation`; the `is_*_active` queries require it.
    negotiation_complete: bool,
}
impl IronSession {
pub fn new() -> Self {
Self {
iron_version: None,
legion_version: None,
encrypted_channels: Vec::new(),
negotiation_complete: false,
}
}
pub fn set_version(&mut self, version: IronVersion) {
self.iron_version = Some(version);
}
pub fn set_legion_version(&mut self, version: LegionVersion) {
self.legion_version = Some(version);
}
pub fn is_iron_active(&self) -> bool {
(self.legion_version.is_some() || self.iron_version.is_some()) && self.negotiation_complete
}
pub fn is_legion_active(&self) -> bool {
self.legion_version.is_some() && self.negotiation_complete
}
pub fn version(&self) -> Option<IronVersion> {
self.iron_version
}
pub fn legion_version(&self) -> Option<LegionVersion> {
self.legion_version
}
pub fn complete_negotiation(&mut self) {
self.negotiation_complete = true;
}
pub fn is_encrypted_channel(&self, channel: &str) -> bool {
self.encrypted_channels.iter().any(|c| c == channel)
}
pub fn add_encrypted_channel(&mut self, channel: String) {
if !self.encrypted_channels.contains(&channel) {
self.encrypted_channels.push(channel);
}
}
pub fn remove_encrypted_channel(&mut self, channel: &str) {
self.encrypted_channels.retain(|c| c != channel);
}
}
impl Default for IronSession {
fn default() -> Self {
Self::new()
}
}
/// Stateless helpers for channel access control and the related user-facing
/// error messages.
pub struct IronChannelHandler;

impl IronChannelHandler {
    /// Decides whether a user may join `channel`.
    ///
    /// - `ChannelType::IrcGlobal` / `ChannelType::IrcLocal` channels are always
    ///   [`ChannelJoinResult::Allowed`].
    /// - `ChannelType::LegionEncrypted` channels are
    ///   [`ChannelJoinResult::AllowedEncrypted`] only when both the user and the
    ///   server support Legion. Otherwise the join is denied with
    ///   [`IronChannelError::IncompatibleClient`].
    ///
    /// # Errors
    ///
    /// Returns `IronError::Parse` when `get_channel_type` classifies `channel`
    /// as `ChannelType::Invalid`.
    pub fn can_join_channel(
        channel: &str,
        user_has_legion: bool,
        server_has_legion: bool,
    ) -> Result<ChannelJoinResult> {
        match get_channel_type(channel) {
            ChannelType::IrcGlobal | ChannelType::IrcLocal => Ok(ChannelJoinResult::Allowed),
            ChannelType::LegionEncrypted if user_has_legion && server_has_legion => {
                Ok(ChannelJoinResult::AllowedEncrypted)
            }
            ChannelType::LegionEncrypted => Ok(ChannelJoinResult::Denied {
                reason: IronChannelError::IncompatibleClient,
            }),
            ChannelType::Invalid => Err(IronError::Parse(format!(
                "Invalid channel name: {}",
                channel
            ))),
        }
    }

    /// Builds a human-readable explanation for a denied channel join.
    ///
    /// For [`IronChannelError::IncompatibleClient`], the message suggests a
    /// standard IRC alternative: `#` followed by `channel` with its first
    /// character removed (e.g. `!secret` suggests `#secret`). The first
    /// character is presumably the channel-type prefix; confirm against
    /// `get_channel_type`.
    ///
    /// This never panics, even for an empty `channel` or one that starts
    /// with a multi-byte character.
    pub fn generate_error_message(channel: &str, error: &IronChannelError) -> String {
        match error {
            IronChannelError::IncompatibleClient => {
                // Drop the first *character*, not the first byte. `&channel[1..]`
                // panics on "" and on names whose first char is multi-byte.
                let mut chars = channel.chars();
                chars.next();
                let base_name = chars.as_str();
                format!(
                    "Cannot join encrypted channel {} - requires Legion Protocol support. \
                     Upgrade to a Legion-compatible client or ask channel admin to create \
                     a standard IRC channel (#{}) for IRC users.",
                    channel, base_name
                )
            }
            IronChannelError::EncryptionRequired => {
                format!(
                    "Channel {} requires end-to-end encryption. \
                     Please use a Legion Protocol-compatible client.",
                    channel
                )
            }
        }
    }
}
/// Result of an access-control check for joining a channel.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ChannelJoinResult {
    /// The join is permitted as a regular IRC channel.
    Allowed,
    /// The join is permitted on an encrypted (Legion) channel.
    AllowedEncrypted,
    /// The join is refused for the given reason.
    Denied { reason: IronChannelError },
}
/// Reasons a channel join can be refused.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum IronChannelError {
    /// The client (or server) lacks Legion Protocol support needed for the channel.
    IncompatibleClient,
    /// The channel requires end-to-end encryption.
    EncryptionRequired,
}
/// Chooses the best protocol tier that both peers advertise.
///
/// Preference order:
/// 1. [`IronNegotiationResult::LegionCapable`] when both sides advertise
///    `Capability::LegionProtocolV1`.
/// 2. [`IronNegotiationResult::IronCapable`] when both sides advertise
///    `Capability::IronProtocolV1`.
/// 3. [`IronNegotiationResult::IrcFallback`] when at least one side advertises
///    either capability but no tier is shared.
/// 4. [`IronNegotiationResult::NotSupported`] when neither side advertises
///    either capability.
pub fn detect_legion_support(
    client_caps: &[Capability],
    server_caps: &[Capability],
) -> IronNegotiationResult {
    let advertises_legion = |caps: &[Capability]| {
        caps.iter()
            .any(|cap| matches!(cap, Capability::LegionProtocolV1))
    };
    let advertises_iron = |caps: &[Capability]| {
        caps.iter()
            .any(|cap| matches!(cap, Capability::IronProtocolV1))
    };

    let client_legion = advertises_legion(client_caps);
    let server_legion = advertises_legion(server_caps);
    let client_iron = advertises_iron(client_caps);
    let server_iron = advertises_iron(server_caps);

    if client_legion && server_legion {
        IronNegotiationResult::LegionCapable {
            version: LegionVersion::V1,
        }
    } else if client_iron && server_iron {
        IronNegotiationResult::IronCapable {
            version: IronVersion::V1,
        }
    } else if client_legion || server_legion || client_iron || server_iron {
        // Some extended protocol is advertised, but not by both peers (or the
        // peers advertise different ones). Fall back to plain IRC.
        IronNegotiationResult::IrcFallback
    } else {
        IronNegotiationResult::NotSupported
    }
}
/// Deprecated alias for [`detect_legion_support`].
///
/// It forwards its arguments unchanged and returns the same result.
#[deprecated(note = "Use detect_legion_support instead")]
pub fn detect_iron_support(
    client_caps: &[Capability],
    server_caps: &[Capability],
) -> IronNegotiationResult {
    // Pure delegation, kept only for backward compatibility with older callers.
    detect_legion_support(client_caps, server_caps)
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_iron_version_capability() {
        let iron = IronVersion::V1;
        let legion = LegionVersion::V1;
        assert_eq!(iron.as_capability(), "+iron-protocol/v1");
        assert_eq!(legion.as_capability(), "+legion-protocol/v1");
        assert_eq!(iron.to_legion_version(), legion);
    }

    #[test]
    fn test_channel_access_control() {
        // Plain IRC channels are open to everyone.
        assert_eq!(
            IronChannelHandler::can_join_channel("#general", false, false).unwrap(),
            ChannelJoinResult::Allowed
        );

        // Encrypted channels are open when both ends speak Legion...
        assert_eq!(
            IronChannelHandler::can_join_channel("!encrypted", true, true).unwrap(),
            ChannelJoinResult::AllowedEncrypted
        );

        // ...and refused when the user does not.
        let denied = IronChannelHandler::can_join_channel("!encrypted", false, true).unwrap();
        assert!(matches!(
            denied,
            ChannelJoinResult::Denied {
                reason: IronChannelError::IncompatibleClient
            }
        ));
    }

    #[test]
    fn test_legion_detection() {
        let client_caps = [Capability::LegionProtocolV1, Capability::MessageTags];
        let server_caps = [Capability::LegionProtocolV1, Capability::Sasl];
        assert_eq!(
            detect_legion_support(&client_caps, &server_caps),
            IronNegotiationResult::LegionCapable {
                version: LegionVersion::V1
            }
        );

        let client_caps = [Capability::IronProtocolV1, Capability::MessageTags];
        let server_caps = [Capability::IronProtocolV1, Capability::Sasl];
        assert_eq!(
            detect_legion_support(&client_caps, &server_caps),
            IronNegotiationResult::IronCapable {
                version: IronVersion::V1
            }
        );

        // Only the server advertises Iron, so negotiation falls back to IRC.
        let client_caps = [Capability::MessageTags];
        assert_eq!(
            detect_legion_support(&client_caps, &server_caps),
            IronNegotiationResult::IrcFallback
        );

        // The deprecated alias is exercised on purpose to keep it covered.
        #[allow(deprecated)]
        let legacy = detect_iron_support(&client_caps, &server_caps);
        assert_eq!(legacy, IronNegotiationResult::IrcFallback);
    }

    #[test]
    fn test_legion_session() {
        let mut iron_session = IronSession::new();
        assert!(!iron_session.is_iron_active());
        assert!(!iron_session.is_legion_active());

        iron_session.set_version(IronVersion::V1);
        iron_session.complete_negotiation();
        assert!(iron_session.is_iron_active());
        assert!(!iron_session.is_legion_active());
        assert_eq!(iron_session.version(), Some(IronVersion::V1));

        let mut legion_session = IronSession::new();
        legion_session.set_legion_version(LegionVersion::V1);
        legion_session.complete_negotiation();
        // A Legion session also reports as Iron-active.
        assert!(legion_session.is_iron_active());
        assert!(legion_session.is_legion_active());
        assert_eq!(legion_session.legion_version(), Some(LegionVersion::V1));

        iron_session.add_encrypted_channel("!secure".to_string());
        assert!(iron_session.is_encrypted_channel("!secure"));
        assert!(!iron_session.is_encrypted_channel("!other"));
    }
}