1use crate::{
4 discover::{AsDiscovered, DiscoveredAdvisory},
5 retrieve::{AsRetrieved, RetrievalContext, RetrievedAdvisory, RetrievedVisitor},
6 source::Source,
7};
8use std::{
9 fmt::{Debug, Display, Formatter},
10 future::Future,
11 marker::PhantomData,
12 ops::{Deref, DerefMut},
13};
14use url::Url;
15use walker_common::{
16 retrieve::RetrievalError,
17 utils::{openpgp::PublicKey, url::Urlify},
18 validate::{digest::validate_digest, openpgp, ValidationOptions},
19};
20
21#[derive(Clone, Debug)]
28pub struct ValidatedAdvisory {
29 pub retrieved: RetrievedAdvisory,
31}
32
33impl Urlify for ValidatedAdvisory {
34 fn url(&self) -> &Url {
35 &self.url
36 }
37
38 fn relative_base_and_url(&self) -> Option<(&Url, String)> {
39 self.retrieved.relative_base_and_url()
40 }
41}
42
43impl Deref for ValidatedAdvisory {
44 type Target = RetrievedAdvisory;
45
46 fn deref(&self) -> &Self::Target {
47 &self.retrieved
48 }
49}
50
51impl DerefMut for ValidatedAdvisory {
52 fn deref_mut(&mut self) -> &mut Self::Target {
53 &mut self.retrieved
54 }
55}
56
57impl AsDiscovered for ValidatedAdvisory {
58 fn as_discovered(&self) -> &DiscoveredAdvisory {
59 &self.discovered
60 }
61}
62
63impl AsRetrieved for ValidatedAdvisory {
64 fn as_retrieved(&self) -> &RetrievedAdvisory {
65 &self.retrieved
66 }
67}
68
69#[derive(Debug, thiserror::Error)]
70pub enum ValidationError<S: Source> {
71 Retrieval(RetrievalError<DiscoveredAdvisory, S>),
72 DigestMismatch {
73 expected: String,
74 actual: String,
75 retrieved: RetrievedAdvisory,
76 },
77 Signature {
78 error: anyhow::Error,
79 retrieved: RetrievedAdvisory,
80 },
81}
82
83impl<S: Source + Debug> AsDiscovered for ValidationError<S> {
84 fn as_discovered(&self) -> &DiscoveredAdvisory {
85 match self {
86 Self::Retrieval(err) => err.discovered(),
87 Self::DigestMismatch { retrieved, .. } => retrieved.as_discovered(),
88 Self::Signature { retrieved, .. } => retrieved.as_discovered(),
89 }
90 }
91}
92
93impl<S: Source> Urlify for ValidationError<S> {
94 fn url(&self) -> &Url {
95 match self {
96 Self::Retrieval(err) => err.url(),
97 Self::DigestMismatch { retrieved, .. } => &retrieved.url,
98 Self::Signature { retrieved, .. } => &retrieved.url,
99 }
100 }
101}
102
103impl<S: Source> Display for ValidationError<S> {
104 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
105 match self {
106 Self::Retrieval(err) => write!(f, "Retrieval error: {err}"),
107 Self::DigestMismatch {
108 expected,
109 actual,
110 retrieved: _,
111 } => write!(
112 f,
113 "Digest mismatch - expected: {expected}, actual: {actual}",
114 ),
115 Self::Signature {
116 error,
117 retrieved: _,
118 } => {
119 write!(f, "Invalid signature: {error}",)
120 }
121 }
122 }
123}
124
125pub struct ValidationContext<'c> {
126 pub retrieval: &'c RetrievalContext<'c>,
127}
128
129impl<'c> Deref for ValidationContext<'c> {
130 type Target = RetrievalContext<'c>;
131
132 fn deref(&self) -> &Self::Target {
133 self.retrieval
134 }
135}
136
137pub trait ValidatedVisitor<S: Source> {
138 type Error: Display + Debug;
139 type Context;
140
141 fn visit_context(
142 &self,
143 context: &ValidationContext,
144 ) -> impl Future<Output = Result<Self::Context, Self::Error>>;
145
146 fn visit_advisory(
147 &self,
148 context: &Self::Context,
149 result: Result<ValidatedAdvisory, ValidationError<S>>,
150 ) -> impl Future<Output = Result<(), Self::Error>>;
151}
152
153impl<F, E, Fut, S> ValidatedVisitor<S> for F
154where
155 F: Fn(Result<ValidatedAdvisory, ValidationError<S>>) -> Fut,
156 Fut: Future<Output = Result<(), E>>,
157 E: Display + Debug,
158 S: Source,
159{
160 type Error = E;
161 type Context = ();
162
163 async fn visit_context(
164 &self,
165 _context: &ValidationContext<'_>,
166 ) -> Result<Self::Context, Self::Error> {
167 Ok(())
168 }
169
170 async fn visit_advisory(
171 &self,
172 _context: &Self::Context,
173 result: Result<ValidatedAdvisory, ValidationError<S>>,
174 ) -> Result<(), Self::Error> {
175 self(result).await
176 }
177}
178
179pub struct ValidationVisitor<V, S>
180where
181 V: ValidatedVisitor<S>,
182 S: Source,
183{
184 visitor: V,
185 options: ValidationOptions,
186 _marker: PhantomData<S>,
187}
188
189enum ValidationProcessError<S: Source> {
190 Proceed(ValidationError<S>),
192 #[allow(unused)]
194 Abort(anyhow::Error),
195}
196
197#[derive(Debug, thiserror::Error)]
198pub enum Error<VE>
199where
200 VE: Display + Debug,
201{
202 #[error("{0}")]
203 Visitor(VE),
204 #[error("Severe validation error: {0}")]
205 Validation(anyhow::Error),
206}
207
208impl<V, S> ValidationVisitor<V, S>
209where
210 V: ValidatedVisitor<S>,
211 S: Source,
212{
213 pub fn new(visitor: V) -> Self {
214 Self {
215 visitor,
216 options: Default::default(),
217 _marker: Default::default(),
218 }
219 }
220
221 pub fn with_options(mut self, options: impl Into<ValidationOptions>) -> Self {
222 self.options = options.into();
223 self
224 }
225
226 async fn validate(
230 &self,
231 context: &InnerValidationContext<V::Context>,
232 retrieved: RetrievedAdvisory,
233 ) -> Result<ValidatedAdvisory, ValidationProcessError<S>> {
234 if let Err((expected, actual)) = validate_digest(&retrieved.sha256) {
235 return Err(ValidationProcessError::Proceed(
236 ValidationError::DigestMismatch {
237 expected,
238 actual,
239 retrieved,
240 },
241 ));
242 }
243 if let Err((expected, actual)) = validate_digest(&retrieved.sha512) {
244 return Err(ValidationProcessError::Proceed(
245 ValidationError::DigestMismatch {
246 expected,
247 actual,
248 retrieved,
249 },
250 ));
251 }
252
253 if let Some(signature) = &retrieved.signature {
254 match openpgp::validate_signature(
255 &self.options,
256 &context.keys,
257 signature,
258 &retrieved.data,
259 ) {
260 Ok(()) => Ok(ValidatedAdvisory { retrieved }),
261 Err(error) => Err(ValidationProcessError::Proceed(
262 ValidationError::Signature { error, retrieved },
263 )),
264 }
265 } else {
266 Ok(ValidatedAdvisory { retrieved })
267 }
268 }
269}
270
271pub struct InnerValidationContext<VC> {
272 context: VC,
273 keys: Vec<PublicKey>,
274}
275
276impl<V, S> RetrievedVisitor<S> for ValidationVisitor<V, S>
277where
278 V: ValidatedVisitor<S>,
279 S: Source,
280{
281 type Error = Error<V::Error>;
282 type Context = InnerValidationContext<V::Context>;
283
284 async fn visit_context(
285 &self,
286 context: &RetrievalContext<'_>,
287 ) -> Result<Self::Context, Self::Error> {
288 let keys = context.keys.clone();
289
290 let context = self
291 .visitor
292 .visit_context(&ValidationContext { retrieval: context })
293 .await
294 .map_err(Error::Visitor)?;
295
296 Ok(Self::Context { context, keys })
297 }
298
299 async fn visit_advisory(
300 &self,
301 context: &Self::Context,
302 outcome: Result<RetrievedAdvisory, RetrievalError<DiscoveredAdvisory, S>>,
303 ) -> Result<(), Self::Error> {
304 match outcome {
305 Ok(advisory) => {
306 let result = match self.validate(context, advisory).await {
307 Ok(result) => Ok(result),
308 Err(ValidationProcessError::Proceed(err)) => Err(err),
309 Err(ValidationProcessError::Abort(err)) => return Err(Error::Validation(err)),
310 };
311 self.visitor
312 .visit_advisory(&context.context, result)
313 .await
314 .map_err(Error::Visitor)?
315 }
316 Err(err) => self
317 .visitor
318 .visit_advisory(&context.context, Err(ValidationError::Retrieval(err)))
319 .await
320 .map_err(Error::Visitor)?,
321 }
322
323 Ok(())
324 }
325}