1use crate::{
4 discover::{DiscoveredContext, DiscoveredSbom, DiscoveredVisitor},
5 source::Source,
6};
7use bytes::Bytes;
8use sha2::{Sha256, Sha512};
9use std::{
10 fmt::Debug,
11 future::Future,
12 ops::{Deref, DerefMut},
13};
14use url::Url;
15use walker_common::{
16 retrieve::{RetrievalError, RetrievalMetadata, RetrievedDigest, RetrievedDocument},
17 utils::{openpgp::PublicKey, url::Urlify},
18 validate::source::{KeySource, KeySourceError},
19};
20
21#[derive(Clone, Debug)]
23pub struct RetrievedSbom {
24 pub discovered: DiscoveredSbom,
26
27 pub data: Bytes,
29 pub signature: Option<String>,
31
32 pub sha256: Option<RetrievedDigest<Sha256>>,
34 pub sha512: Option<RetrievedDigest<Sha512>>,
36
37 pub metadata: RetrievalMetadata,
39}
40
41impl Urlify for RetrievedSbom {
42 fn url(&self) -> &Url {
43 &self.url
44 }
45
46 fn relative_base_and_url(&self) -> Option<(&Url, String)> {
47 self.discovered.relative_base_and_url()
48 }
49}
50
51impl Deref for RetrievedSbom {
52 type Target = DiscoveredSbom;
53
54 fn deref(&self) -> &Self::Target {
55 &self.discovered
56 }
57}
58
59impl DerefMut for RetrievedSbom {
60 fn deref_mut(&mut self) -> &mut Self::Target {
61 &mut self.discovered
62 }
63}
64
65impl RetrievedDocument for RetrievedSbom {
66 type Discovered = DiscoveredSbom;
67}
68
69pub struct RetrievalContext<'c> {
70 pub discovered: &'c DiscoveredContext<'c>,
71 pub keys: &'c Vec<PublicKey>,
72}
73
74impl<'c> Deref for RetrievalContext<'c> {
75 type Target = DiscoveredContext<'c>;
76
77 fn deref(&self) -> &Self::Target {
78 self.discovered
79 }
80}
81
82pub trait RetrievedVisitor<S: Source> {
83 type Error: std::fmt::Display + Debug;
84 type Context;
85
86 fn visit_context(
87 &self,
88 context: &RetrievalContext,
89 ) -> impl Future<Output = Result<Self::Context, Self::Error>>;
90
91 fn visit_sbom(
92 &self,
93 context: &Self::Context,
94 result: Result<RetrievedSbom, RetrievalError<DiscoveredSbom, S>>,
95 ) -> impl Future<Output = Result<(), Self::Error>>;
96}
97
98impl<F, E, Fut, S> RetrievedVisitor<S> for F
99where
100 F: Fn(Result<RetrievedSbom, RetrievalError<DiscoveredSbom, S>>) -> Fut,
101 Fut: Future<Output = Result<(), E>>,
102 E: std::fmt::Display + Debug,
103 S: Source,
104{
105 type Error = E;
106 type Context = ();
107
108 async fn visit_context(
109 &self,
110 _context: &RetrievalContext<'_>,
111 ) -> Result<Self::Context, Self::Error> {
112 Ok(())
113 }
114
115 async fn visit_sbom(
116 &self,
117 _ctx: &Self::Context,
118 outcome: Result<RetrievedSbom, RetrievalError<DiscoveredSbom, S>>,
119 ) -> Result<(), Self::Error> {
120 self(outcome).await
121 }
122}
123
/// A discovered-document visitor which loads the context's public keys and each
/// discovered SBOM from `source`, then forwards the outcome to an inner visitor.
///
/// The type itself carries no trait bounds; the `impl` blocks that need
/// `V: RetrievedVisitor<S>` and `S: Source + KeySource` state them. This keeps the bounds
/// from having to be repeated wherever the type is merely named.
pub struct RetrievingVisitor<V, S> {
    /// The visitor receiving retrieval outcomes.
    visitor: V,
    /// The source used to load public keys and SBOM documents.
    source: S,
}
128
129impl<V, S> RetrievingVisitor<V, S>
130where
131 V: RetrievedVisitor<S>,
132 S: Source + KeySource,
133{
134 pub fn new(source: S, visitor: V) -> Self {
135 Self { visitor, source }
136 }
137}
138
139#[derive(Debug, thiserror::Error)]
140pub enum Error<VE, SE, KSE>
141where
142 VE: std::fmt::Display + Debug,
143 SE: std::fmt::Display + Debug,
144 KSE: std::fmt::Display + Debug,
145{
146 #[error("Source error: {0}")]
147 Source(SE),
148 #[error("Key source error: {0}")]
149 KeySource(KeySourceError<KSE>),
150 #[error(transparent)]
151 Visitor(VE),
152}
153
154impl<V, S> DiscoveredVisitor for RetrievingVisitor<V, S>
155where
156 V: RetrievedVisitor<S>,
157 S: Source + KeySource,
158{
159 type Error =
160 Error<V::Error, <S as walker_common::source::Source>::Error, <S as KeySource>::Error>;
161 type Context = V::Context;
162
163 async fn visit_context(
164 &self,
165 context: &DiscoveredContext<'_>,
166 ) -> Result<Self::Context, Self::Error> {
167 let mut keys = Vec::with_capacity(context.metadata.keys.len());
168
169 for key in &context.metadata.keys {
170 keys.push(
171 self.source
172 .load_public_key(key.into())
173 .await
174 .map_err(Error::KeySource)?,
175 );
176 }
177
178 log::info!(
179 "Loaded {} public key{}",
180 keys.len(),
181 if keys.len() != 1 {
182 "s"
183 } else {
184 Default::default()
185 }
186 );
187 if log::log_enabled!(log::Level::Debug) {
188 for key in keys.iter().flat_map(|k| &k.certs) {
189 log::debug!(" {}", key.key_handle());
190 for id in key.userids() {
191 log::debug!(" {}", id.userid());
192 }
193 }
194 }
195
196 self.visitor
197 .visit_context(&RetrievalContext {
198 keys: &keys,
199 discovered: context,
200 })
201 .await
202 .map_err(Error::Visitor)
203 }
204
205 async fn visit_sbom(
206 &self,
207 context: &Self::Context,
208 discovered: DiscoveredSbom,
209 ) -> Result<(), Self::Error> {
210 let sbom = self
211 .source
212 .load_sbom(discovered.clone())
213 .await
214 .map_err(|err| RetrievalError::Source { err, discovered });
215
216 self.visitor
217 .visit_sbom(context, sbom)
218 .await
219 .map_err(Error::Visitor)?;
220
221 Ok(())
222 }
223}