use crate::dns::Message;
use std::any::Any;
use std::collections::HashMap;
use std::sync::Arc;
/// Bundles a DNS request [`Message`] with an optional response and a
/// string-keyed bag of type-erased metadata values.
#[derive(Debug)]
pub struct Context {
/// The request message this context was created with.
request: Message,
/// The response, if one has been set. It is held behind an `Arc` so an
/// already-shared message can be stored without copying it.
response: Option<Arc<Message>>,
/// Arbitrary values keyed by name. Each value is boxed as
/// `dyn Any + Send + Sync` and recovered by downcasting to its concrete type.
metadata: HashMap<String, Box<dyn Any + Send + Sync>>,
}
impl Context {
pub fn new(request: Message) -> Self {
Self {
request,
response: None,
metadata: HashMap::new(),
}
}
pub fn request(&self) -> &Message {
&self.request
}
pub fn request_mut(&mut self) -> &mut Message {
&mut self.request
}
pub fn response(&self) -> Option<&Message> {
self.response.as_deref()
}
pub fn response_mut(&mut self) -> Option<&mut Message> {
self.response.as_mut().map(Arc::make_mut)
}
pub fn set_response(&mut self, response: Option<Message>) {
self.response = response.map(Arc::new);
}
pub fn set_response_arc(&mut self, response: Option<Arc<Message>>) {
self.response = response;
}
pub fn take_response(&mut self) -> Option<Message> {
self.response.take().map(|arc| match Arc::try_unwrap(arc) {
Ok(msg) => msg,
Err(shared) => (*shared).clone(),
})
}
pub fn has_response(&self) -> bool {
self.response.is_some()
}
pub fn set_metadata<T: Any + Send + Sync>(&mut self, key: impl Into<String>, value: T) {
self.metadata.insert(key.into(), Box::new(value));
}
pub fn get_metadata<T: Any>(&self, key: &str) -> Option<&T> {
self.metadata.get(key).and_then(|v| v.downcast_ref::<T>())
}
pub fn remove_metadata(&mut self, key: &str) -> bool {
self.metadata.remove(key).is_some()
}
pub fn has_metadata(&self, key: &str) -> bool {
self.metadata.contains_key(key)
}
pub fn clear_metadata(&mut self) {
self.metadata.clear();
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dns::{Question, RecordClass, RecordType};

    /// A new context has no response and exposes the request it was given.
    #[test]
    fn test_context_creation() {
        let ctx = Context::new(Message::new());

        assert!(!ctx.has_response());
        assert_eq!(ctx.request().question_count(), 0);
    }

    /// The request handed to `Context::new` is readable through `request()`.
    #[test]
    fn test_request_access() {
        let mut query = Message::new();
        query.set_id(1234);
        let question = Question::new("example.com".to_string(), RecordType::A, RecordClass::IN);
        query.add_question(question);

        let ctx = Context::new(query);

        assert_eq!(ctx.request().id(), 1234);
        assert_eq!(ctx.request().question_count(), 1);
    }

    /// Setting a response makes it visible; taking it empties the context.
    #[test]
    fn test_response_handling() {
        let mut ctx = Context::new(Message::new());
        assert!(!ctx.has_response());
        assert!(ctx.response().is_none());

        let mut reply = Message::new();
        reply.set_id(5678);
        ctx.set_response(Some(reply));
        assert!(ctx.has_response());
        assert_eq!(ctx.response().unwrap().id(), 5678);

        let taken = ctx.take_response();
        assert!(taken.is_some());
        assert!(!ctx.has_response());
    }

    /// Metadata values come back only under their exact stored type, and
    /// removal reports whether the key was present.
    #[test]
    fn test_metadata() {
        let mut ctx = Context::new(Message::new());

        ctx.set_metadata("string", "test");
        ctx.set_metadata("number", 42i32);
        ctx.set_metadata("bool", true);

        assert_eq!(ctx.get_metadata::<&str>("string"), Some(&"test"));
        assert_eq!(ctx.get_metadata::<i32>("number"), Some(&42));
        assert_eq!(ctx.get_metadata::<bool>("bool"), Some(&true));
        // Asking for a different type than the one stored yields nothing.
        assert!(ctx.get_metadata::<i64>("number").is_none());

        assert!(ctx.has_metadata("string"));
        assert!(!ctx.has_metadata("nonexistent"));

        assert!(ctx.remove_metadata("string"));
        assert!(!ctx.has_metadata("string"));
        // A second removal of the same key finds nothing to remove.
        assert!(!ctx.remove_metadata("string"));
    }

    /// `clear_metadata` drops every stored entry.
    #[test]
    fn test_clear_metadata() {
        let mut ctx = Context::new(Message::new());

        ctx.set_metadata("key1", "value1");
        ctx.set_metadata("key2", 123);
        assert!(ctx.has_metadata("key1"));
        assert!(ctx.has_metadata("key2"));

        ctx.clear_metadata();

        assert!(!ctx.has_metadata("key1"));
        assert!(!ctx.has_metadata("key2"));
    }

    /// Changes made through `request_mut()` are visible through `request()`.
    #[test]
    fn test_mutable_request() {
        let mut ctx = Context::new(Message::new());

        ctx.request_mut().set_id(9999);
        assert_eq!(ctx.request().id(), 9999);

        let question = Question::new("test.com".to_string(), RecordType::A, RecordClass::IN);
        ctx.request_mut().add_question(question);
        assert_eq!(ctx.request().question_count(), 1);
    }

    /// Changes made through `response_mut()` are visible through `response()`.
    #[test]
    fn test_mutable_response() {
        let mut ctx = Context::new(Message::new());
        ctx.set_response(Some(Message::new()));

        if let Some(reply) = ctx.response_mut() {
            reply.set_id(1111);
        }

        assert_eq!(ctx.response().unwrap().id(), 1111);
    }
}