use iri_string::types::UriString;
use std::{fmt::Display, str::FromStr};
use thiserror::Error;
use time::OffsetDateTime;
use crate::timestamp::TimeStamp;
/// Text that follows the domain on the first line of a SIWS message.
const PREAMBLE: &str = " wants you to sign in with your Solana account:";
/// Line prefix for the optional URI field.
const URI_TAG: &str = "URI: ";
/// Line prefix for the optional version field.
const VERSION_TAG: &str = "Version: ";
/// Line prefix for the optional chain identifier field.
const CHAIN_TAG: &str = "Chain ID: ";
/// Line prefix for the optional nonce field.
const NONCE_TAG: &str = "Nonce: ";
/// Line prefix for the optional issued-at timestamp.
const IAT_TAG: &str = "Issued At: ";
/// Line prefix for the optional expiration timestamp.
const EXP_TAG: &str = "Expiration Time: ";
/// Line prefix for the optional not-before timestamp.
const NBF_TAG: &str = "Not Before: ";
/// Line prefix for the optional request identifier.
const RID_TAG: &str = "Request ID: ";
/// Header line of the resource list; each following line is `- {uri}`.
const RES_TAG: &str = "Resources:";
/// `ParseError::Format` payload when the first line does not end with `PREAMBLE`.
const ERR_MSG_PREAMBLE: &str = "Missing or malformed Preamble Line";
/// `ParseError::Format` payload when the address line is absent.
const ERR_MSG_ADDR: &str = "Missing or malformed Address Line";
/// A Sign-In With Solana (SIWS) message.
///
/// Serialized with `String::from(&msg)` / `Display`, and parsed with `FromStr`
/// or `TryFrom<&[u8]>`. Optional fields are `None` (or, for `resources`, empty)
/// when the corresponding line is absent from the text form.
#[derive(Default, Debug, Clone)]
pub struct SiwsMessage {
/// Text preceding the preamble on the first line.
pub domain: String,
/// Second line of the message, kept verbatim; the parser does not check its format.
pub address: String,
/// Value of the `URI: ` line.
pub uri: Option<String>,
/// Value of the `Version: ` line.
pub version: Option<String>,
/// Free-text line placed between the address and the advanced fields.
pub statement: Option<String>,
/// Value of the `Nonce: ` line.
pub nonce: Option<String>,
/// Value of the `Chain ID: ` line.
pub chain_id: Option<String>,
/// Parsed `Issued At: ` timestamp.
pub issued_at: Option<TimeStamp>,
/// Parsed `Expiration Time: ` timestamp.
pub expiration_time: Option<TimeStamp>,
/// Parsed `Not Before: ` timestamp.
pub not_before: Option<TimeStamp>,
/// Value of the `Request ID: ` line.
pub request_id: Option<String>,
/// URIs listed as `- {uri}` lines under `Resources:`.
pub resources: Vec<UriString>,
}
/// Expectations checked by `SiwsMessage::validate`; every `None` field skips its check.
#[derive(Default, Debug)]
pub struct ValidateOptions {
/// Instant against which the message's issued-at, expiration and not-before times are compared.
pub time: Option<OffsetDateTime>,
/// Domain the message's `domain` must equal.
pub domain: Option<String>,
/// Nonce the message's nonce is compared against.
pub nonce: Option<String>,
}
#[derive(Debug, Error, PartialEq)]
pub enum ValidateError {
#[error("Domain mismatch.")]
Domain,
#[error("Message is expired.")]
ExpirationTime,
#[error("'Issued At' is before current time.")]
IssuedAt,
#[error("'Not Before' is before current time.")]
NotBefore,
}
#[derive(Error, Debug)]
pub enum ParseError {
#[error("Formatting Error: {0}")]
Format(&'static str),
#[error("Invalid TimeStamp: {0}")]
TimeStamp(#[from] time::Error),
#[error("Invalid URI: {0}")]
Uri(#[from] iri_string::validate::Error),
}
impl SiwsMessage {
pub fn validate(&self, options: ValidateOptions) -> Result<(), ValidateError> {
if let Some(domain) = options.domain {
if self.domain != domain {
return Err(ValidateError::Domain);
}
}
if let Some(options_nonce) = &options.nonce {
if let Some(message_nonce) = &self.nonce {
if message_nonce != options_nonce {
return Err(ValidateError::ExpirationTime);
}
}
}
if let Some(check_time) = options.time {
if let Some(issued_at) = &self.issued_at {
if issued_at > &check_time {
return Err(ValidateError::IssuedAt);
}
}
if let Some(expiration_time) = &self.expiration_time {
if expiration_time > &check_time {
return Err(ValidateError::ExpirationTime);
}
}
if let Some(not_before) = &self.not_before {
if not_before < &check_time {
return Err(ValidateError::NotBefore);
}
}
}
Ok(())
}
}
impl Display for SiwsMessage {
    /// Writes the SIWS text form, exactly as produced by `String::from(&SiwsMessage)`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let rendered = String::from(self);
        f.write_str(&rendered)
    }
}
impl From<&SiwsMessage> for String {
    /// Serializes a [`SiwsMessage`] into the SIWS text format.
    ///
    /// The output is the preamble line (`{domain}` + preamble), the address line,
    /// then the statement (if any) and the block of advanced fields (if any are set).
    /// Each of those two sections is preceded by a blank line. Advanced fields
    /// appear in a fixed order, one per line, and end with `Resources:` followed
    /// by `- {uri}` items.
    fn from(msg: &SiwsMessage) -> Self {
        let mut text = format!("{}{}\n{}", msg.domain, PREAMBLE, msg.address);

        if let Some(statement) = &msg.statement {
            text.push_str("\n\n");
            text.push_str(statement);
        }

        // Every rendered field already begins with its own "\n", so gluing them
        // together with no separator yields one field per line.
        let fields = [
            fmt_advanced_field(URI_TAG, &msg.uri),
            fmt_advanced_field(VERSION_TAG, &msg.version),
            fmt_advanced_field(CHAIN_TAG, &msg.chain_id),
            fmt_advanced_field(NONCE_TAG, &msg.nonce),
            fmt_advanced_field(IAT_TAG, &msg.issued_at),
            fmt_advanced_field(EXP_TAG, &msg.expiration_time),
            fmt_advanced_field(NBF_TAG, &msg.not_before),
            fmt_advanced_field(RID_TAG, &msg.request_id),
            fmt_advanced_field_list(RES_TAG, &msg.resources),
        ]
        .concat();

        // One extra "\n" turns the first field's leading newline into a blank separator line.
        if !fields.is_empty() {
            text.push('\n');
            text.push_str(&fields);
        }

        text
    }
}
impl TryFrom<&[u8]> for SiwsMessage {
type Error = ParseError;
fn try_from(value: &[u8]) -> Result<Self, Self::Error> {
let message_string: String = std::str::from_utf8(value)
.expect("Message should be valid UTF-8 byte array!")
.into();
SiwsMessage::from_str(&message_string)
}
}
impl TryFrom<&Vec<u8>> for SiwsMessage {
type Error = ParseError;
fn try_from(value: &Vec<u8>) -> Result<Self, Self::Error> {
let message_string: String = std::str::from_utf8(value)
.expect("Message should be valid UTF-8 byte array!")
.into();
SiwsMessage::from_str(&message_string)
}
}
impl FromStr for SiwsMessage {
type Err = ParseError;
fn from_str(s: &str) -> Result<Self, Self::Err> {
let mut lines = s.split('\n');
let domain = lines
.next()
.and_then(|preamble| preamble.strip_suffix(PREAMBLE))
.map(|s| s.to_string())
.ok_or(ParseError::Format(ERR_MSG_PREAMBLE))?;
let address = lines
.next()
.map(|s| s.to_string())
.ok_or(ParseError::Format(ERR_MSG_ADDR))?;
lines.next();
let mut line = lines.next();
println!("Line: {:?}", line);
let statement = match line {
None => None,
Some("") => None,
Some(s) => {
if starts_with_advanced_field_tag(s) {
None
} else {
lines.next();
line = lines.next();
Some(s.to_string())
}
}
};
println!("Statement: {:?}", statement);
let uri = match tag_optional(URI_TAG, line)? {
Some(exp) => {
line = lines.next();
Some(String::from(exp))
}
None => None,
};
let version = match tag_optional(VERSION_TAG, line)? {
Some(exp) => {
line = lines.next();
Some(String::from(exp))
}
None => None,
};
let chain_id = match tag_optional(CHAIN_TAG, line)? {
Some(exp) => {
line = lines.next();
Some(String::from(exp))
}
None => None,
};
let nonce = match tag_optional(NONCE_TAG, line)? {
Some(exp) => {
line = lines.next();
Some(String::from(exp))
}
None => None,
};
let issued_at = match tag_optional(IAT_TAG, line)? {
Some(exp) => {
line = lines.next();
Some(exp.parse()?)
}
None => None,
};
let expiration_time = match tag_optional(EXP_TAG, line)? {
Some(exp) => {
line = lines.next();
Some(exp.parse()?)
}
None => None,
};
let not_before = match tag_optional(NBF_TAG, line)? {
Some(exp) => {
line = lines.next();
Some(exp.parse()?)
}
None => None,
};
let request_id = match tag_optional(RID_TAG, line)? {
Some(exp) => {
line = lines.next();
Some(String::from(exp))
}
None => None,
};
let resources: Vec<UriString> = match line {
Some(RES_TAG) => lines.map(|s| parse_line("- ", Some(s))).collect(),
Some(_) => Err(ParseError::Format("Unexpected content")),
None => Ok(vec![]),
}?;
Ok(SiwsMessage {
domain,
address,
statement,
uri,
version,
chain_id,
nonce,
issued_at,
expiration_time,
not_before,
request_id,
resources,
})
}
}
/// Renders an optional field as `"\n{name}{value}"`, or as an empty string when absent.
///
/// The leading newline lets the rendered fields be concatenated directly.
fn fmt_advanced_field<T: std::fmt::Display>(name: &'static str, value: &Option<T>) -> String {
    value
        .as_ref()
        .map(|present| format!("\n{name}{present}"))
        .unwrap_or_default()
}
/// Renders a list field as `"\n{name}"` followed by one `"\n- {item}"` per item.
///
/// Returns an empty string when `value` is empty, so the header is omitted too.
/// The function accepts any `Display` item; for example, resources are `UriString`s.
fn fmt_advanced_field_list<T: std::fmt::Display>(name: &'static str, value: &[T]) -> String {
    use std::fmt::Write as _;

    if value.is_empty() {
        return String::new();
    }
    let mut rendered = format!("\n{name}");
    for item in value {
        // Writing into a `String` cannot fail, so the `fmt::Result` is ignored.
        let _ = write!(rendered, "\n- {item}");
    }
    rendered
}
/// Strips `tag` from `line` and parses the remainder as `S`.
///
/// Fails with the `tagged` error when the tag is missing. Otherwise a failure
/// from `S::from_str` is converted into a [`ParseError`].
fn parse_line<S: FromStr<Err = E>, E: Into<ParseError>>(
    tag: &'static str,
    line: Option<&str>,
) -> Result<S, ParseError> {
    let payload = tagged(tag, line)?;
    S::from_str(payload).map_err(Into::into)
}
/// Like `tagged`, but a missing `tag` yields `Ok(None)` instead of an error.
///
/// Any other error from `tagged` is passed through unchanged.
fn tag_optional<'a>(
    tag: &'static str,
    line: Option<&'a str>,
) -> Result<Option<&'a str>, ParseError> {
    match tagged(tag, line) {
        Ok(rest) => Ok(Some(rest)),
        Err(ParseError::Format(missing)) if missing == tag => Ok(None),
        Err(other) => Err(other),
    }
}
/// Returns the text after `tag` on `line`.
///
/// Fails with `ParseError::Format(tag)` when `line` is `None` or does not start with `tag`.
fn tagged<'a>(tag: &'static str, line: Option<&'a str>) -> Result<&'a str, ParseError> {
    match line.and_then(|text| text.strip_prefix(tag)) {
        Some(rest) => Ok(rest),
        None => Err(ParseError::Format(tag)),
    }
}
/// Reports whether `line` begins with any advanced-field tag (including `Resources:`).
fn starts_with_advanced_field_tag(line: &str) -> bool {
    const ADVANCED_TAGS: [&str; 9] = [
        URI_TAG,
        VERSION_TAG,
        CHAIN_TAG,
        NONCE_TAG,
        IAT_TAG,
        EXP_TAG,
        NBF_TAG,
        RID_TAG,
        RES_TAG,
    ];
    ADVANCED_TAGS.iter().any(|tag| line.starts_with(*tag))
}
#[cfg(test)]
mod test {
    use super::*;
    use matches::assert_matches;

    const TEST_DOMAIN: &str = "localhost";
    const TEST_ADDR: &str = "0000000000000000000000000000000000000000";

    /// Parses `msg`, panicking if parsing unexpectedly succeeds, and returns the error.
    fn expect_parse_error(msg: &str) -> ParseError {
        match SiwsMessage::from_str(msg) {
            Err(err) => err,
            Ok(_) => panic!("Should return an error!"),
        }
    }

    #[test]
    fn parse_throws_on_empty_message() {
        let err = expect_parse_error("");
        assert_matches!(err, ParseError::Format(ERR_MSG_PREAMBLE));
    }

    #[test]
    fn parse_throws_on_no_address() {
        let msg = format!("{TEST_DOMAIN}{PREAMBLE}");
        let err = expect_parse_error(&msg);
        assert_matches!(err, ParseError::Format(ERR_MSG_ADDR));
    }

    #[test]
    fn parse_throws_on_invalid_timestamp() {
        let msg = format!("{TEST_DOMAIN}{PREAMBLE}\n{TEST_ADDR}\n\n{NBF_TAG}invalid");
        let err = expect_parse_error(&msg);
        assert_matches!(err, ParseError::TimeStamp(_));
    }

    #[test]
    fn parse_throws_on_invalid_uri() {
        let msg = format!("{TEST_DOMAIN}{PREAMBLE}\n{TEST_ADDR}\n\n{RES_TAG}\n- invalid");
        let err = expect_parse_error(&msg);
        assert_matches!(err, ParseError::Uri(_));
    }

    #[test]
    fn minimal_parse() {
        let msg = format!("{TEST_DOMAIN}{PREAMBLE}\n{TEST_ADDR}");
        let parsed = SiwsMessage::from_str(&msg).unwrap_or_else(|_| panic!("Should not error!"));
        assert_eq!(TEST_DOMAIN, parsed.domain);
        assert_eq!(TEST_ADDR, parsed.address);
    }

    #[test]
    fn full_parse() -> Result<(), ParseError> {
        let msg = include_str!("../tests/full_message.txt");
        let parsed = SiwsMessage::from_str(msg).unwrap_or_else(|e| panic!("{}", e));
        assert_eq!(TEST_DOMAIN, parsed.domain);
        assert_eq!(TEST_ADDR, parsed.address);
        assert_eq!(Some("did:key:example".into()), parsed.uri);
        assert_eq!(Some("1".into()), parsed.version);
        assert_eq!(Some("testnet".into()), parsed.chain_id);
        assert_eq!(Some("mynonce1".into()), parsed.nonce);
        assert_eq!(
            Some(TimeStamp::from_str("2022-06-21T12:00:00.000Z")?),
            parsed.issued_at
        );
        Ok(())
    }
}