use std::sync::Arc;
use crate::error::{ErrorStatus, Result};
use crate::handler::{MibHandler, RequestContext};
use crate::oid::Oid;
use crate::pdu::Pdu;
use crate::value::Value;
use crate::version::Version;
use super::Agent;
impl Agent {
    /// Handle an SNMP SET request using a test-then-commit protocol.
    ///
    /// **Phase 1 (test):** for each varbind, in request order:
    /// - if VACM is configured, the OID is checked against the request's write view;
    /// - the OID is routed to its registered handler;
    /// - the handler's `test_set` validates the value.
    ///
    /// On the first failure, `free_set` is called in reverse order on every varbind that
    /// already passed `test_set`. An error response is then returned whose `error-index` is
    /// the 1-based position of the failing varbind.
    ///
    /// **Phase 2 (commit):** `commit_set` is called on each validated varbind in request
    /// order. If one fails, `undo_set` is called in reverse order on every varbind that
    /// committed before it. The response then carries `UndoFailed` if any undo failed,
    /// otherwise `CommitFailed`.
    ///
    /// For SNMPv1 requests, error statuses are mapped to their v1 equivalents. The access
    /// and no-handler cases map explicitly to `NoSuchName`; handler and commit errors map
    /// via `ErrorStatus::to_v1`.
    ///
    /// Protocol failures are encoded in the returned PDU; every path in this function
    /// returns `Ok`. On success the result is `pdu.to_response()`.
    ///
    /// NOTE(review): when `commit_set` fails, the failing varbind and all varbinds after
    /// it receive neither `undo_set` nor `free_set`. Any resources reserved for them in
    /// `test_set` are not released here. Confirm this matches the `MibHandler` contract.
    pub(super) async fn handle_set(&self, ctx: &RequestContext, pdu: &Pdu) -> Result<Pdu> {
        /// A varbind that passed `test_set`, together with the handler that accepted it.
        struct PendingSet<'a> {
            handler: &'a Arc<dyn MibHandler>,
            oid: Oid,
            value: Value,
        }

        /// Release `test_set` reservations, newest first, after a phase-1 failure.
        async fn free_pending(ctx: &RequestContext, pending: &[PendingSet<'_>]) {
            for p in pending.iter().rev() {
                p.handler.free_set(ctx, &p.oid, &p.value).await;
            }
        }

        /// Convert a 0-based varbind position into the 1-based `error-index` field.
        fn error_index(index: usize) -> i32 {
            // Saturate instead of silently wrapping on an absurdly large PDU.
            i32::try_from(index + 1).unwrap_or(i32::MAX)
        }

        let is_v1 = ctx.version == Version::V1;
        let mut pending: Vec<PendingSet<'_>> = Vec::with_capacity(pdu.varbinds.len());

        // Phase 1: validate every varbind before changing anything.
        for (index, vb) in pdu.varbinds.iter().enumerate() {
            if let Some(ref vacm) = self.inner.vacm
                && !vacm.check_access(ctx.write_view.as_ref(), &vb.oid)
            {
                free_pending(ctx, &pending).await;
                let status = if is_v1 {
                    ErrorStatus::NoSuchName
                } else {
                    ErrorStatus::NoAccess
                };
                return Ok(pdu.to_error_response(status, error_index(index)));
            }

            let Some(handler) = self.find_handler(&vb.oid) else {
                free_pending(ctx, &pending).await;
                let status = if is_v1 {
                    ErrorStatus::NoSuchName
                } else {
                    ErrorStatus::NotWritable
                };
                return Ok(pdu.to_error_response(status, error_index(index)));
            };

            let result = handler.handler.test_set(ctx, &vb.oid, &vb.value).await;
            if !result.is_ok() {
                free_pending(ctx, &pending).await;
                let status = result.to_error_status();
                let status = if is_v1 { status.to_v1() } else { status };
                return Ok(pdu.to_error_response(status, error_index(index)));
            }

            pending.push(PendingSet {
                handler: &handler.handler,
                oid: vb.oid.clone(),
                value: vb.value.clone(),
            });
        }

        // Phase 2: apply. Commits run in request order and we stop at the first
        // failure, so when the varbind at `index` fails exactly `pending[..index]`
        // has been committed and needs undoing.
        for (index, p) in pending.iter().enumerate() {
            if p.handler.commit_set(ctx, &p.oid, &p.value).await.is_ok() {
                continue;
            }

            // Attempt every undo even if one fails, so as much state as possible is restored.
            let mut undo_failed = false;
            for c in pending[..index].iter().rev() {
                let undo_result = c.handler.undo_set(ctx, &c.oid, &c.value).await;
                if !undo_result.is_ok() {
                    undo_failed = true;
                    tracing::warn!(target: "async_snmp::agent", { oid = %c.oid }, "undo_set failed during rollback");
                }
            }

            let status = if undo_failed {
                ErrorStatus::UndoFailed
            } else {
                ErrorStatus::CommitFailed
            };
            let status = if is_v1 { status.to_v1() } else { status };
            return Ok(pdu.to_error_response(status, error_index(index)));
        }

        Ok(pdu.to_response())
    }
}
#[cfg(test)]
mod tests {
    use std::sync::Arc;
    use std::sync::atomic::{AtomicU32, Ordering};
    use bytes::Bytes;
    use crate::Oid;
    use crate::agent::Agent;
    use crate::handler::{
        BoxFuture, GetNextResult, GetResult, MibHandler, RequestContext, SecurityModel, SetResult,
    };
    use crate::message::SecurityLevel;
    use crate::oid;
    use crate::pdu::{Pdu, PduType};
    use crate::value::Value;
    use crate::varbind::VarBind;
    use crate::version::Version;

    /// Test handler that accepts every SET except one OID and counts `free_set` calls.
    ///
    /// `get` reports `NoSuchObject`, `get_next` reports `EndOfMibView`, and
    /// `commit_set` always succeeds.
    struct FreeSetTracker {
        /// Shared counter bumped once per `free_set` invocation.
        free_count: Arc<AtomicU32>,
    }

    impl FreeSetTracker {
        /// Create a tracker and return it together with a handle to its counter.
        fn with_counter() -> (Arc<Self>, Arc<AtomicU32>) {
            let counter = Arc::new(AtomicU32::new(0));
            let tracker = Arc::new(Self {
                free_count: Arc::clone(&counter),
            });
            (tracker, counter)
        }
    }

    impl MibHandler for FreeSetTracker {
        fn get<'a>(&'a self, _ctx: &'a RequestContext, _oid: &'a Oid) -> BoxFuture<'a, GetResult> {
            Box::pin(async { GetResult::NoSuchObject })
        }

        fn get_next<'a>(
            &'a self,
            _ctx: &'a RequestContext,
            _oid: &'a Oid,
        ) -> BoxFuture<'a, GetNextResult> {
            Box::pin(async { GetNextResult::EndOfMibView })
        }

        fn test_set<'a>(
            &'a self,
            _ctx: &'a RequestContext,
            oid: &'a Oid,
            _value: &'a Value,
        ) -> BoxFuture<'a, SetResult> {
            Box::pin(async move {
                // Only `.2.0` under the test subtree is refused; everything else is accepted.
                let rejected = oid!(1, 3, 6, 1, 4, 1, 99999, 2, 0);
                match *oid == rejected {
                    true => SetResult::WrongValue,
                    false => SetResult::Ok,
                }
            })
        }

        fn commit_set<'a>(
            &'a self,
            _ctx: &'a RequestContext,
            _oid: &'a Oid,
            _value: &'a Value,
        ) -> BoxFuture<'a, SetResult> {
            Box::pin(async { SetResult::Ok })
        }

        fn free_set<'a>(
            &'a self,
            _ctx: &'a RequestContext,
            _oid: &'a Oid,
            _value: &'a Value,
        ) -> BoxFuture<'a, ()> {
            // Count at call time; the returned future has no further work to do.
            self.free_count.fetch_add(1, Ordering::Relaxed);
            Box::pin(async {})
        }
    }

    /// A v2c, noAuthNoPriv SetRequest context from a loopback source.
    fn test_ctx() -> RequestContext {
        RequestContext {
            source: "127.0.0.1:12345".parse().unwrap(),
            version: Version::V2c,
            security_model: SecurityModel::V2c,
            security_name: Bytes::from_static(b"public"),
            security_level: SecurityLevel::NoAuthNoPriv,
            context_name: Bytes::new(),
            request_id: 1,
            pdu_type: PduType::SetRequest,
            group_name: None,
            read_view: None,
            write_view: None,
            msg_max_size: None,
        }
    }

    /// Build a fresh SetRequest PDU (request-id 1, no error set) carrying `varbinds`.
    fn set_request(varbinds: Vec<VarBind>) -> Pdu {
        Pdu {
            pdu_type: PduType::SetRequest,
            request_id: 1,
            error_status: 0,
            error_index: 0,
            varbinds,
        }
    }

    #[tokio::test]
    async fn test_free_set_called_on_test_failure() {
        let (tracker, free_count) = FreeSetTracker::with_counter();
        let agent = Agent::builder()
            .bind("127.0.0.1:0")
            .community(b"public")
            .handler(oid!(1, 3, 6, 1, 4, 1, 99999), tracker)
            .build()
            .await
            .unwrap();
        let ctx = test_ctx();
        // First varbind passes test_set; the second is rejected.
        let request = set_request(vec![
            VarBind::new(oid!(1, 3, 6, 1, 4, 1, 99999, 1, 0), Value::Integer(1)),
            VarBind::new(oid!(1, 3, 6, 1, 4, 1, 99999, 2, 0), Value::Integer(2)),
        ]);

        let response = agent.dispatch_request(&ctx, &request).await.unwrap();

        assert_eq!(response.error_index, 2);
        // Only the already-accepted first varbind has its reservation released.
        assert_eq!(free_count.load(Ordering::Relaxed), 1);
    }

    #[tokio::test]
    async fn test_free_set_not_called_on_success() {
        let (tracker, free_count) = FreeSetTracker::with_counter();
        let agent = Agent::builder()
            .bind("127.0.0.1:0")
            .community(b"public")
            .handler(oid!(1, 3, 6, 1, 4, 1, 99999), tracker)
            .build()
            .await
            .unwrap();
        let ctx = test_ctx();
        let request = set_request(vec![VarBind::new(
            oid!(1, 3, 6, 1, 4, 1, 99999, 1, 0),
            Value::Integer(1),
        )]);

        let response = agent.dispatch_request(&ctx, &request).await.unwrap();

        assert_eq!(response.error_status, 0);
        assert_eq!(free_count.load(Ordering::Relaxed), 0);
    }

    #[tokio::test]
    async fn test_free_set_not_called_when_first_varbind_fails() {
        let (tracker, free_count) = FreeSetTracker::with_counter();
        let agent = Agent::builder()
            .bind("127.0.0.1:0")
            .community(b"public")
            .handler(oid!(1, 3, 6, 1, 4, 1, 99999), tracker)
            .build()
            .await
            .unwrap();
        let ctx = test_ctx();
        // The only varbind is rejected, so nothing was ever reserved.
        let request = set_request(vec![VarBind::new(
            oid!(1, 3, 6, 1, 4, 1, 99999, 2, 0),
            Value::Integer(1),
        )]);

        let response = agent.dispatch_request(&ctx, &request).await.unwrap();

        assert_eq!(response.error_index, 1);
        assert_eq!(free_count.load(Ordering::Relaxed), 0);
    }
}