use std::str::FromStr;
use chrono::Utc;
use jsonwebtoken::{DecodingKey, EncodingKey, Header, Validation, decode, encode};
use serde::{Deserialize, Serialize};
use thiserror::Error;
/// Issuer that incoming challenge tokens must carry (`iss`), reused as the
/// audience (`aud`) of the response tokens this crate signs.
pub const OTOROSHI_ISSUER: &str = "Otoroshi";

/// Default name of the header carrying Otoroshi's state challenge.
pub const DEFAULT_STATE_HEADER: &str = "Otoroshi-State";
/// Default name of the header carrying the answer to the state challenge.
pub const DEFAULT_STATE_RESP_HEADER: &str = "Otoroshi-State-Resp";

/// Default lifetime, in seconds, of signed response tokens.
pub const DEFAULT_TOKEN_EXPIRY_SECONDS: i64 = 30;
/// HMAC algorithms supported for verifying challenges and signing responses.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum Algorithm {
    /// HMAC using SHA-256.
    HS256,
    /// HMAC using SHA-384.
    HS384,
    /// HMAC using SHA-512; this is the default.
    #[default]
    HS512,
}

impl FromStr for Algorithm {
    type Err = std::convert::Infallible;

    /// Parses an algorithm name case-insensitively, ignoring surrounding
    /// whitespace. Any unrecognised name (including the empty string) falls
    /// back to [`Algorithm::HS512`], so parsing never fails.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        // Trim first: values read from env vars or config files often carry a
        // trailing newline/space, which would otherwise silently select HS512
        // and make every signature check fail.
        Ok(match s.trim().to_uppercase().as_str() {
            "HS256" => Algorithm::HS256,
            "HS384" => Algorithm::HS384,
            _ => Algorithm::HS512,
        })
    }
}
impl Algorithm {
fn as_jsonwebtoken(self) -> jsonwebtoken::Algorithm {
match self {
Algorithm::HS256 => jsonwebtoken::Algorithm::HS256,
Algorithm::HS384 => jsonwebtoken::Algorithm::HS384,
Algorithm::HS512 => jsonwebtoken::Algorithm::HS512,
}
}
}
#[derive(Debug, Error)]
pub enum ProtocolError {
#[error("JWT verification failed: {0}")]
VerificationFailed(#[from] jsonwebtoken::errors::Error),
#[error("Failed to create response token: {0}")]
EncodingFailed(jsonwebtoken::errors::Error),
}
/// Claims read from an incoming challenge token.
///
/// Only `state` is deserialized into this struct; other claims in the
/// payload are ignored by `serde`.
#[derive(Deserialize, Debug)]
struct ChallengeClaims {
    /// Opaque value that must be echoed back in the response token.
    state: String,
}
/// Claims of the response token sent back to Otoroshi.
///
/// NOTE: `serde` serializes fields in declaration order, so this order is the
/// claim order in the JWT payload; keep it stable.
#[derive(Debug, Serialize)]
struct ResponseClaims {
    /// Echo of the challenge `state`, serialized under the key `state-resp`.
    #[serde(rename = "state-resp")]
    state_resp: String,
    /// Audience of the response token.
    aud: String,
    /// Issued-at time, Unix seconds.
    iat: i64,
    /// Not-before time, Unix seconds.
    nbf: i64,
    /// Expiry time, Unix seconds.
    exp: i64,
}
/// Configuration for answering Otoroshi's state challenge.
///
/// `algo_in`/`secret_in` verify incoming challenge tokens; `algo_out`/`secret_out`
/// sign the response tokens, which expire `ttl` seconds after issuance.
///
/// `Debug` is implemented by hand so that logging this struct never prints the
/// HMAC secrets.
#[derive(Clone)]
pub struct OtoroshiProtocol {
    /// Algorithm required on incoming challenge tokens.
    pub algo_in: Algorithm,
    /// HMAC secret used to verify incoming challenge tokens.
    pub secret_in: Vec<u8>,
    /// Algorithm used to sign response tokens.
    pub algo_out: Algorithm,
    /// HMAC secret used to sign response tokens.
    pub secret_out: Vec<u8>,
    /// Lifetime of response tokens, in seconds.
    pub ttl: i64,
}

impl std::fmt::Debug for OtoroshiProtocol {
    /// Formats every field except the two secrets, which are shown as
    /// `"<redacted>"`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("OtoroshiProtocol")
            .field("algo_in", &self.algo_in)
            .field("secret_in", &"<redacted>")
            .field("algo_out", &self.algo_out)
            .field("secret_out", &"<redacted>")
            .field("ttl", &self.ttl)
            .finish()
    }
}
impl OtoroshiProtocol {
    /// Clock-skew tolerance, in seconds, applied to the `exp` and `nbf` claims
    /// of incoming challenge tokens.
    const CHALLENGE_LEEWAY_SECONDS: u64 = 10;

    /// Creates a handler that verifies challenges and signs responses with the
    /// same `secret` and `algorithm`, using [`DEFAULT_TOKEN_EXPIRY_SECONDS`] as
    /// the response token lifetime.
    pub fn new(secret: &[u8], algorithm: Algorithm) -> Self {
        Self::new_with_ttl(secret, algorithm, DEFAULT_TOKEN_EXPIRY_SECONDS)
    }

    /// Like [`OtoroshiProtocol::new`], but with an explicit response token
    /// lifetime `ttl`, in seconds.
    pub fn new_with_ttl(secret: &[u8], algorithm: Algorithm, ttl: i64) -> Self {
        Self {
            algo_in: algorithm,
            secret_in: secret.to_vec(),
            algo_out: algorithm,
            secret_out: secret.to_vec(),
            ttl,
        }
    }

    /// Creates a handler that verifies challenges with `secret_in`/`algo_in`
    /// and signs responses with `secret_out`/`algo_out`, using
    /// [`DEFAULT_TOKEN_EXPIRY_SECONDS`] as the response token lifetime.
    pub fn new_asymmetric(
        secret_in: &[u8],
        algo_in: Algorithm,
        secret_out: &[u8],
        algo_out: Algorithm,
    ) -> Self {
        Self {
            algo_in,
            secret_in: secret_in.to_vec(),
            algo_out,
            secret_out: secret_out.to_vec(),
            ttl: DEFAULT_TOKEN_EXPIRY_SECONDS,
        }
    }

    /// Answers a V1 challenge: the state value is returned unchanged.
    pub fn process_v1(&self, state: &str) -> String {
        state.to_string()
    }

    /// Answers a V2 challenge: verifies `token` and returns a signed response
    /// token carrying the challenge's `state`.
    ///
    /// # Errors
    ///
    /// [`ProtocolError::VerificationFailed`] if the challenge is rejected by
    /// [`OtoroshiProtocol::verify_challenge`]; [`ProtocolError::EncodingFailed`]
    /// if the response token cannot be signed.
    pub fn process_v2(&self, token: &str) -> Result<String, ProtocolError> {
        let state = self.verify_challenge(token)?;
        self.create_response_token(&state)
    }

    /// Verifies a challenge token and returns its `state` claim.
    ///
    /// The token must be signed with `secret_in` using exactly `algo_in`, must
    /// carry `iss == OTOROSHI_ISSUER` and an `exp` claim, and must not be
    /// expired. An `nbf` claim, when present, must not be in the future. Both
    /// time checks allow `CHALLENGE_LEEWAY_SECONDS` of clock skew.
    ///
    /// # Errors
    ///
    /// [`ProtocolError::VerificationFailed`] if decoding or any check fails.
    pub fn verify_challenge(&self, token: &str) -> Result<String, ProtocolError> {
        let mut validation = Validation::new(self.algo_in.as_jsonwebtoken());
        validation.set_required_spec_claims(&["exp", "iss"]);
        validation.set_issuer(&[OTOROSHI_ISSUER]);
        // This handler has no configured expected audience, so `aud` is not checked.
        validation.validate_aud = false;
        // RFC 7519 §4.1.5: a JWT must not be accepted before its `nbf` time.
        // jsonwebtoken leaves this check off by default; it only applies when
        // the claim is present, so challenges without `nbf` are unaffected.
        validation.validate_nbf = true;
        validation.leeway = Self::CHALLENGE_LEEWAY_SECONDS;
        let token_data = decode::<ChallengeClaims>(
            token,
            &DecodingKey::from_secret(&self.secret_in),
            &validation,
        )?;
        Ok(token_data.claims.state)
    }

    /// Signs a response token with `secret_out`/`algo_out`.
    ///
    /// The token carries `state` (under `state-resp`), `aud = OTOROSHI_ISSUER`,
    /// `iat = nbf = now` and `exp = now + ttl`, all in Unix seconds.
    ///
    /// # Errors
    ///
    /// [`ProtocolError::EncodingFailed`] if signing fails.
    pub fn create_response_token(&self, state: &str) -> Result<String, ProtocolError> {
        let now = Utc::now().timestamp();
        let claims = ResponseClaims {
            state_resp: state.to_string(),
            aud: OTOROSHI_ISSUER.to_string(),
            iat: now,
            nbf: now,
            // Saturate so an oversized `ttl` cannot overflow (a panic in debug
            // builds, or a wrap to an already-expired time in release builds).
            exp: now.saturating_add(self.ttl),
        };
        encode(
            &Header::new(self.algo_out.as_jsonwebtoken()),
            &claims,
            &EncodingKey::from_secret(&self.secret_out),
        )
        .map_err(ProtocolError::EncodingFailed)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    const TEST_SECRET: &[u8] = b"test-secret-key-for-testing";

    /// Challenge claims as a test sends them. Optional claims are left out of
    /// the JWT payload when `None`, so tests can build challenges that lack
    /// `iss`, `iat` or `exp`.
    #[derive(Serialize)]
    struct Challenge {
        state: String,
        #[serde(skip_serializing_if = "Option::is_none")]
        iss: Option<String>,
        #[serde(skip_serializing_if = "Option::is_none")]
        iat: Option<i64>,
        #[serde(skip_serializing_if = "Option::is_none")]
        exp: Option<i64>,
    }

    /// Only the echoed state of a response token.
    #[derive(Deserialize)]
    struct ResponseState {
        #[serde(rename = "state-resp")]
        state_resp: String,
    }

    /// Current Unix time in seconds.
    fn now() -> i64 {
        Utc::now().timestamp()
    }

    /// A well-formed challenge: Otoroshi issuer, expiring in 60 seconds, no `iat`.
    fn challenge(state: &str) -> Challenge {
        Challenge {
            state: state.to_string(),
            iss: Some(OTOROSHI_ISSUER.to_string()),
            iat: None,
            exp: Some(now() + 60),
        }
    }

    /// Signs `claims` with `alg` and `secret`.
    fn sign_with(claims: &Challenge, alg: jsonwebtoken::Algorithm, secret: &[u8]) -> String {
        encode(&Header::new(alg), claims, &EncodingKey::from_secret(secret)).unwrap()
    }

    /// Signs `claims` with HS512, the algorithm most tests use.
    fn sign(claims: &Challenge, secret: &[u8]) -> String {
        sign_with(claims, jsonwebtoken::Algorithm::HS512, secret)
    }

    /// Decodes a response token, requiring `aud == OTOROSHI_ISSUER` but no
    /// particular registered claims.
    fn decode_response<T: serde::de::DeserializeOwned>(
        token: &str,
        alg: jsonwebtoken::Algorithm,
        secret: &[u8],
    ) -> jsonwebtoken::errors::Result<T> {
        let mut validation = Validation::new(alg);
        validation.set_required_spec_claims::<&str>(&[]);
        validation.set_audience(&[OTOROSHI_ISSUER]);
        decode::<T>(token, &DecodingKey::from_secret(secret), &validation).map(|data| data.claims)
    }

    #[test]
    fn test_algorithm_from_str_hs256() {
        for name in ["HS256", "hs256"] {
            assert_eq!(Algorithm::from_str(name).unwrap(), Algorithm::HS256);
        }
    }

    #[test]
    fn test_algorithm_from_str_hs384() {
        for name in ["HS384", "hs384"] {
            assert_eq!(Algorithm::from_str(name).unwrap(), Algorithm::HS384);
        }
    }

    #[test]
    fn test_algorithm_from_str_hs512() {
        for name in ["HS512", "hs512"] {
            assert_eq!(Algorithm::from_str(name).unwrap(), Algorithm::HS512);
        }
    }

    #[test]
    fn test_algorithm_from_str_defaults_to_hs512() {
        for name in ["unknown", ""] {
            assert_eq!(Algorithm::from_str(name).unwrap(), Algorithm::HS512);
        }
    }

    #[test]
    fn test_algorithm_default() {
        assert_eq!(Algorithm::default(), Algorithm::HS512);
    }

    #[test]
    fn test_protocol_new() {
        let protocol = OtoroshiProtocol::new(TEST_SECRET, Algorithm::HS256);
        assert_eq!(protocol.algo_in, Algorithm::HS256);
        assert_eq!(protocol.algo_out, Algorithm::HS256);
        assert_eq!(protocol.secret_in, TEST_SECRET);
        assert_eq!(protocol.secret_out, TEST_SECRET);
        assert_eq!(protocol.ttl, DEFAULT_TOKEN_EXPIRY_SECONDS);
    }

    #[test]
    fn test_protocol_new_with_ttl() {
        let protocol = OtoroshiProtocol::new_with_ttl(TEST_SECRET, Algorithm::HS384, 60);
        assert_eq!(protocol.algo_in, Algorithm::HS384);
        assert_eq!(protocol.ttl, 60);
    }

    #[test]
    fn test_protocol_new_asymmetric() {
        let secret_in = b"secret-in";
        let secret_out = b"secret-out";
        let protocol = OtoroshiProtocol::new_asymmetric(
            secret_in,
            Algorithm::HS256,
            secret_out,
            Algorithm::HS512,
        );
        assert_eq!(protocol.algo_in, Algorithm::HS256);
        assert_eq!(protocol.algo_out, Algorithm::HS512);
        assert_eq!(protocol.secret_in, secret_in);
        assert_eq!(protocol.secret_out, secret_out);
    }

    #[test]
    fn test_process_v1_echoes_state() {
        let protocol = OtoroshiProtocol::new(TEST_SECRET, Algorithm::HS512);
        for state in ["test-state", "", "special-chars-!@#$%"] {
            assert_eq!(protocol.process_v1(state), state);
        }
    }

    #[test]
    fn test_create_response_token_success() {
        let protocol = OtoroshiProtocol::new(TEST_SECRET, Algorithm::HS512);
        let token = protocol.create_response_token("my-state").unwrap();
        assert_eq!(token.split('.').count(), 3);
    }

    #[test]
    fn test_create_response_token_can_be_decoded() {
        let protocol = OtoroshiProtocol::new(TEST_SECRET, Algorithm::HS512);
        let token = protocol.create_response_token("my-state").unwrap();
        #[derive(Deserialize)]
        struct TestClaims {
            #[serde(rename = "state-resp")]
            state_resp: String,
            aud: String,
            iat: i64,
            nbf: i64,
            exp: i64,
        }
        let claims: TestClaims =
            decode_response(&token, jsonwebtoken::Algorithm::HS512, TEST_SECRET).unwrap();
        assert_eq!(claims.state_resp, "my-state");
        assert_eq!(claims.aud, OTOROSHI_ISSUER);
        assert_eq!(claims.nbf, claims.iat);
        assert!(claims.exp > claims.iat);
    }

    #[test]
    fn test_create_response_token_respects_ttl() {
        let protocol = OtoroshiProtocol::new_with_ttl(TEST_SECRET, Algorithm::HS512, 120);
        let token = protocol.create_response_token("state").unwrap();
        #[derive(Deserialize)]
        struct TestClaims {
            iat: i64,
            exp: i64,
        }
        let claims: TestClaims =
            decode_response(&token, jsonwebtoken::Algorithm::HS512, TEST_SECRET).unwrap();
        assert_eq!(claims.exp - claims.iat, 120);
    }

    #[test]
    fn test_verify_challenge_valid_token() {
        let protocol = OtoroshiProtocol::new(TEST_SECRET, Algorithm::HS512);
        let token = sign(
            &Challenge {
                iat: Some(now()),
                ..challenge("challenge-state")
            },
            TEST_SECRET,
        );
        assert_eq!(protocol.verify_challenge(&token).unwrap(), "challenge-state");
    }

    #[test]
    fn test_verify_challenge_invalid_signature() {
        let protocol = OtoroshiProtocol::new(TEST_SECRET, Algorithm::HS512);
        let token = sign(&challenge("state"), b"wrong-secret");
        assert!(protocol.verify_challenge(&token).is_err());
    }

    #[test]
    fn test_verify_challenge_malformed_token() {
        let protocol = OtoroshiProtocol::new(TEST_SECRET, Algorithm::HS512);
        assert!(protocol.verify_challenge("not-a-valid-jwt").is_err());
    }

    #[test]
    fn test_verify_challenge_missing_issuer() {
        let protocol = OtoroshiProtocol::new(TEST_SECRET, Algorithm::HS512);
        let token = sign(
            &Challenge {
                iss: None,
                ..challenge("test")
            },
            TEST_SECRET,
        );
        assert!(protocol.verify_challenge(&token).is_err());
    }

    #[test]
    fn test_verify_challenge_wrong_issuer() {
        let protocol = OtoroshiProtocol::new(TEST_SECRET, Algorithm::HS512);
        let token = sign(
            &Challenge {
                iss: Some("NotOtoroshi".to_string()),
                ..challenge("test")
            },
            TEST_SECRET,
        );
        assert!(protocol.verify_challenge(&token).is_err());
    }

    #[test]
    fn test_verify_challenge_missing_expiration() {
        let protocol = OtoroshiProtocol::new(TEST_SECRET, Algorithm::HS512);
        let token = sign(
            &Challenge {
                exp: None,
                ..challenge("test")
            },
            TEST_SECRET,
        );
        assert!(protocol.verify_challenge(&token).is_err());
    }

    #[test]
    fn test_verify_challenge_rejects_algorithm_mismatch() {
        // Correct key, but the header advertises HS256 while HS512 is configured:
        // the token must be rejected rather than verified under the header's choice.
        let protocol = OtoroshiProtocol::new(TEST_SECRET, Algorithm::HS512);
        let token = sign_with(&challenge("test"), jsonwebtoken::Algorithm::HS256, TEST_SECRET);
        assert!(protocol.verify_challenge(&token).is_err());
    }

    #[test]
    fn test_process_v2_full_flow() {
        let protocol = OtoroshiProtocol::new(TEST_SECRET, Algorithm::HS512);
        let challenge_token = sign(
            &Challenge {
                iat: Some(now()),
                ..challenge("roundtrip-state")
            },
            TEST_SECRET,
        );
        let response_token = protocol.process_v2(&challenge_token).unwrap();
        let claims: ResponseState =
            decode_response(&response_token, jsonwebtoken::Algorithm::HS512, TEST_SECRET)
                .unwrap();
        assert_eq!(claims.state_resp, "roundtrip-state");
    }

    #[test]
    fn test_process_v2_asymmetric_uses_separate_keys() {
        let protocol = OtoroshiProtocol::new_asymmetric(
            b"secret-in",
            Algorithm::HS256,
            b"secret-out",
            Algorithm::HS384,
        );
        let challenge_token = sign_with(
            &challenge("split-keys"),
            jsonwebtoken::Algorithm::HS256,
            b"secret-in",
        );
        let response_token = protocol.process_v2(&challenge_token).unwrap();
        let claims: ResponseState =
            decode_response(&response_token, jsonwebtoken::Algorithm::HS384, b"secret-out")
                .unwrap();
        assert_eq!(claims.state_resp, "split-keys");
        // The response must not be verifiable with the inbound secret.
        assert!(decode_response::<serde_json::Value>(
            &response_token,
            jsonwebtoken::Algorithm::HS384,
            b"secret-in",
        )
        .is_err());
    }

    #[test]
    fn test_create_and_verify_with_hs256() {
        let protocol = OtoroshiProtocol::new(TEST_SECRET, Algorithm::HS256);
        let token = protocol.create_response_token("state").unwrap();
        let result = decode_response::<serde_json::Value>(
            &token,
            jsonwebtoken::Algorithm::HS256,
            TEST_SECRET,
        );
        assert!(result.is_ok());
    }

    #[test]
    fn test_create_and_verify_with_hs384() {
        let protocol = OtoroshiProtocol::new(TEST_SECRET, Algorithm::HS384);
        let token = protocol.create_response_token("state").unwrap();
        let result = decode_response::<serde_json::Value>(
            &token,
            jsonwebtoken::Algorithm::HS384,
            TEST_SECRET,
        );
        assert!(result.is_ok());
    }

    #[test]
    fn test_verify_challenge_expired_within_leeway() {
        let protocol = OtoroshiProtocol::new(TEST_SECRET, Algorithm::HS512);
        let token = sign(
            &Challenge {
                exp: Some(now() - 5),
                ..challenge("leeway-test")
            },
            TEST_SECRET,
        );
        assert_eq!(protocol.verify_challenge(&token).unwrap(), "leeway-test");
    }

    #[test]
    fn test_verify_challenge_expired_beyond_leeway() {
        let protocol = OtoroshiProtocol::new(TEST_SECRET, Algorithm::HS512);
        let token = sign(
            &Challenge {
                exp: Some(now() - 15),
                ..challenge("expired-test")
            },
            TEST_SECRET,
        );
        assert!(protocol.verify_challenge(&token).is_err());
    }
}