use crate::dispatcher::{DispatchError, SessionDispatcher};
use crate::message::Respondable;
use crate::oid::ParseObjectIdentifierError;
use num_bigint::BigInt;
use rasn::{ber, Decode, Encode};
use rasn_smi::v1::ObjectSyntax;
use rasn_snmp::v1;
use std::convert::Infallible;
use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr};
use std::sync::atomic::{AtomicU16, Ordering};
use std::sync::Arc;
use std::time::Duration;
use thiserror::Error;
use tokio::io;
use tokio::net::UdpSocket;
/// Describes the per-version types a `Session` is parameterized over.
pub(super) trait VersionedSession {
    /// Version-specific settings, stored in the `snmp` field of `SessionOptions`.
    type Options;
    /// Message type handed to the dispatcher and returned as the response.
    type Message;
}
/// Configuration used to open a session.
///
/// `O` carries the protocol-version-specific part of the configuration.
pub struct SessionOptions<O> {
    /// Remote address the session's dispatcher communicates with.
    pub target: SocketAddr,
    /// Request timeout; handed to the dispatcher at construction and on every send.
    pub timeout: Duration,
    /// Version-specific options.
    pub snmp: O,
}
/// Errors produced while creating or using a session.
#[derive(Debug, Error)]
pub enum Error {
    /// Socket-level failure. This also covers the `TimedOut` error reported
    /// when a response token fails.
    #[error("I/O error: {0}")]
    Io(#[from] io::Error),
    /// A message could not be BER-encoded.
    #[error("encoding error: {0}")]
    EncodingError(ber::enc::EncodeError),
    /// An object identifier failed to parse.
    #[error("invalid object identifier: {0}")]
    InvalidObjectIdentifier(ParseObjectIdentifierError),
    /// A response did not have the expected type; the string describes what was received.
    #[error("unexpected response: {0}")]
    UnexpectedResponseType(String),
    /// A value whose SMI syntax is not supported.
    #[error("unsupported value type: {0:?}")]
    UnsupportedValueType(ObjectSyntax),
}
/// Lifts dispatcher failures into the session-level `Error`.
impl From<DispatchError> for Error {
    fn from(err: DispatchError) -> Self {
        // Kept exhaustive on purpose: a new `DispatchError` variant should be
        // a compile error here rather than being silently mis-mapped.
        match err {
            DispatchError::Encoding(inner) => Self::EncodingError(inner),
            DispatchError::Io(inner) => Self::Io(inner),
        }
    }
}
/// Lets `?` be used on `Result<_, Infallible>` inside functions returning `Error`.
impl From<Infallible> for Error {
    fn from(never: Infallible) -> Self {
        // `Infallible` has no values, so this empty match is exhaustive. The
        // compiler proves the body unreachable, which leaves no runtime panic
        // path (unlike `unreachable!()`).
        match never {}
    }
}
/// Handle to a session.
///
/// Every clone of the handle shares one `SessionInner` (dispatcher,
/// request-id counter and options) through an `Arc`.
pub struct Session<V: VersionedSession> {
    pub(super) inner: Arc<SessionInner<V>>,
}
impl<V: VersionedSession> Clone for Session<V> {
fn clone(&self) -> Self {
Self {
inner: self.inner.clone(),
}
}
}
/// State shared by every clone of a session handle.
pub(super) struct SessionInner<V: VersionedSession> {
    /// Dispatcher used to send messages and await their responses.
    pub(super) dispatcher: SessionDispatcher<V::Message>,
    /// Source of request IDs, advanced by `next_request_id`.
    pub(super) request_id: AtomicU16,
    /// Options the session was created with.
    pub(super) options: SessionOptions<V::Options>,
}
impl<V: VersionedSession> Session<V> {
    /// Returns the current request-ID counter value as a `BigInt` and
    /// increments the shared counter.
    ///
    /// The counter is a `u16`, and atomic `fetch_add` wraps on overflow. IDs
    /// therefore repeat after 65536 calls.
    pub(super) fn next_request_id(&self) -> BigInt {
        let id: u16 = self.inner.request_id.fetch_add(1, Ordering::SeqCst);
        id.into()
    }
}
impl<V: VersionedSession> Session<V>
where
V::Message: Encode + Decode + Respondable + Send + Clone + 'static,
{
pub(super) async fn new(options: SessionOptions<V::Options>) -> io::Result<Session<V>> {
let addr: SocketAddr = "0.0.0.0:0".parse().unwrap();
let socket = Arc::new(UdpSocket::bind(&addr).await?);
let dispatcher = SessionDispatcher::new(socket, options.target, options.timeout);
dispatcher.spawn_receiver();
Ok(Self {
inner: Arc::new(SessionInner {
dispatcher,
options,
request_id: AtomicU16::new(1),
}),
})
}
pub(super) async fn send(&self, message: V::Message) -> Result<V::Message, Error> {
let token = self
.inner
.dispatcher
.send(self.inner.options.timeout, message)
.await?;
token
.await
.map_err(|_| Error::Io(io::Error::new(io::ErrorKind::TimedOut, "timeout")))
}
}