cw_controllers/
hooks.rs

1use schemars::JsonSchema;
2use std::fmt;
3use thiserror::Error;
4
5use cosmwasm_schema::cw_serde;
6use cosmwasm_std::{
7    attr, Addr, CustomQuery, Deps, DepsMut, MessageInfo, Response, StdError, StdResult, Storage,
8    SubMsg,
9};
10use cw_storage_plus::{Item, Namespace};
11
12use crate::admin::{Admin, AdminError};
13
14// this is copied from cw4
15// TODO: pull into utils as common dep
16#[cw_serde]
17pub struct HooksResponse {
18    pub hooks: Vec<String>,
19}
20
21#[derive(Error, Debug)]
22pub enum HookError {
23    #[error("{0}")]
24    Std(#[from] StdError),
25
26    #[error("{0}")]
27    Admin(#[from] AdminError),
28
29    #[error("Given address already registered as a hook")]
30    HookAlreadyRegistered {},
31
32    #[error("Given address not registered as a hook")]
33    HookNotRegistered {},
34}
35
36// store all hook addresses in one item. We cannot have many of them before the contract becomes unusable anyway.
37pub struct Hooks(Item<Vec<Addr>>);
38
39impl Hooks {
40    pub const fn new(storage_key: &'static str) -> Self {
41        Hooks(Item::new(storage_key))
42    }
43
44    pub fn new_dyn(storage_key: impl Into<Namespace>) -> Self {
45        Hooks(Item::new_dyn(storage_key))
46    }
47
48    pub fn add_hook(&self, storage: &mut dyn Storage, addr: Addr) -> Result<(), HookError> {
49        let mut hooks = self.0.may_load(storage)?.unwrap_or_default();
50        if !hooks.contains(&addr) {
51            hooks.push(addr);
52        } else {
53            return Err(HookError::HookAlreadyRegistered {});
54        }
55        Ok(self.0.save(storage, &hooks)?)
56    }
57
58    pub fn remove_hook(&self, storage: &mut dyn Storage, addr: Addr) -> Result<(), HookError> {
59        let mut hooks = self.0.load(storage)?;
60        if let Some(p) = hooks.iter().position(|x| x == addr) {
61            hooks.remove(p);
62        } else {
63            return Err(HookError::HookNotRegistered {});
64        }
65        Ok(self.0.save(storage, &hooks)?)
66    }
67
68    pub fn prepare_hooks<F: Fn(Addr) -> StdResult<SubMsg>>(
69        &self,
70        storage: &dyn Storage,
71        prep: F,
72    ) -> StdResult<Vec<SubMsg>> {
73        self.0
74            .may_load(storage)?
75            .unwrap_or_default()
76            .into_iter()
77            .map(prep)
78            .collect()
79    }
80
81    pub fn execute_add_hook<C, Q: CustomQuery>(
82        &self,
83        admin: &Admin,
84        deps: DepsMut<Q>,
85        info: MessageInfo,
86        addr: Addr,
87    ) -> Result<Response<C>, HookError>
88    where
89        C: Clone + fmt::Debug + PartialEq + JsonSchema,
90    {
91        admin.assert_admin(deps.as_ref(), &info.sender)?;
92        self.add_hook(deps.storage, addr.clone())?;
93
94        let attributes = vec![
95            attr("action", "add_hook"),
96            attr("hook", addr),
97            attr("sender", info.sender),
98        ];
99        Ok(Response::new().add_attributes(attributes))
100    }
101
102    pub fn execute_remove_hook<C, Q: CustomQuery>(
103        &self,
104        admin: &Admin,
105        deps: DepsMut<Q>,
106        info: MessageInfo,
107        addr: Addr,
108    ) -> Result<Response<C>, HookError>
109    where
110        C: Clone + fmt::Debug + PartialEq + JsonSchema,
111    {
112        admin.assert_admin(deps.as_ref(), &info.sender)?;
113        self.remove_hook(deps.storage, addr.clone())?;
114
115        let attributes = vec![
116            attr("action", "remove_hook"),
117            attr("hook", addr),
118            attr("sender", info.sender),
119        ];
120        Ok(Response::new().add_attributes(attributes))
121    }
122
123    pub fn query_hooks<Q: CustomQuery>(&self, deps: Deps<Q>) -> StdResult<HooksResponse> {
124        let hooks = self.0.may_load(deps.storage)?.unwrap_or_default();
125        let hooks = hooks.into_iter().map(String::from).collect();
126        Ok(HooksResponse { hooks })
127    }
128
129    // Return true if hook is in hooks
130    pub fn query_hook<Q: CustomQuery>(&self, deps: Deps<Q>, hook: String) -> StdResult<bool> {
131        Ok(self.query_hooks(deps)?.hooks.into_iter().any(|h| h == hook))
132    }
133}
134
135// TODO: add test coverage