Skip to main content

daemon/grpc_local_impl/
signal.rs

1// SPDX-License-Identifier: Apache-2.0
2//! Local-mode `SignalService`. Reads risk signals attached to states by W1
3//! (`State.risk_signals`) and reports per-signal fire rates over a rolling
4//! window. The actual signal *computation* lands in R3 (`crates/state_review`);
5//! this service exposes whatever's already on disk.
6
7use std::{collections::HashMap, pin::Pin};
8
9use futures::Stream;
10use grpc::heddle::v1::{
11    ComputeStateSignalsRequest, ComputeStateSignalsResponse, GetRepoSignalHealthRequest,
12    PathSymbolRef, RepoSignalHealthReport, RiskSignal as ProtoRiskSignal,
13    SignalAnchor as ProtoSignalAnchor, SignalHealthEntry, SignalUpdate,
14    SubscribeSignalUpdatesRequest, signal_service_server::SignalService,
15};
16use objects::object::{ChangeId, RiskSignal, RiskSignalBlob, State};
17use repo::Repository;
18use tokio_stream::wrappers::ReceiverStream;
19use tonic::{Request, Response, Status};
20
21use super::{GrpcLocalService, to_status};
22
23#[derive(Clone)]
24pub struct LocalSignalService {
25    inner: GrpcLocalService,
26}
27
28impl LocalSignalService {
29    pub fn new(inner: GrpcLocalService) -> Self {
30        Self { inner }
31    }
32
33    fn repo(&self) -> &Repository {
34        self.inner.repo()
35    }
36}
37
38#[tonic::async_trait]
39impl SignalService for LocalSignalService {
40    type SubscribeSignalUpdatesStream =
41        Pin<Box<dyn Stream<Item = Result<SignalUpdate, Status>> + Send>>;
42
43    async fn compute_state_signals(
44        &self,
45        request: Request<ComputeStateSignalsRequest>,
46    ) -> Result<Response<ComputeStateSignalsResponse>, Status> {
47        let req = request.into_inner();
48        if req.state_id.is_empty() {
49            return Err(Status::invalid_argument("state_id is required"));
50        }
51        let change_id = ChangeId::try_from_slice(&req.state_id)
52            .map_err(|err| Status::invalid_argument(format!("invalid state_id: {err}")))?;
53        let state = self
54            .repo()
55            .store()
56            .get_state(&change_id)
57            .map_err(to_status)?
58            .ok_or_else(|| Status::not_found(format!("state {change_id} not found")))?;
59        // Real R3 computation lives in `crates/state_review`. For W2 this
60        // service surfaces whatever signals already live on the state via
61        // `state.risk_signals`. When R3 wires through, replace this with
62        // a call into `state_review::registry::run_all`.
63        let signals = load_signals(self.repo(), &state)?;
64        let proto_signals = signals
65            .iter()
66            .map(|s| signal_to_proto(s, "visible"))
67            .collect();
68        Ok(Response::new(ComputeStateSignalsResponse {
69            signals: proto_signals,
70            tick_budget: 3,
71        }))
72    }
73
74    async fn get_repo_signal_health(
75        &self,
76        request: Request<GetRepoSignalHealthRequest>,
77    ) -> Result<Response<RepoSignalHealthReport>, Status> {
78        let req = request.into_inner();
79        let window = if req.window_states == 0 {
80            DEFAULT_HEALTH_WINDOW
81        } else {
82            req.window_states.min(MAX_HEALTH_WINDOW) as usize
83        };
84        // Walk recent states from HEAD's first-parent chain, up to `window`.
85        // For each, collect any `RiskSignal`s on disk. Aggregate per
86        // module_id; fire_rate = states-with-signals / states-considered.
87        let states = walk_recent_states(self.repo(), window).map_err(to_status)?;
88        let visited = states.len() as u32;
89        let mut per_module: HashMap<String, u32> = HashMap::new();
90        for state in &states {
91            let signals = load_signals(self.repo(), state)?;
92            // One state contributes at most once per module so a noisy
93            // module isn't double-counted by firing many signals on the
94            // same state.
95            let mut seen_modules = std::collections::HashSet::new();
96            for sig in &signals {
97                let key = sig.producer.module.clone();
98                if seen_modules.insert(key.clone()) {
99                    *per_module.entry(key).or_insert(0) += 1;
100                }
101            }
102        }
103        let warn_threshold = 0.5_f32;
104        let entries = per_module
105            .into_iter()
106            .map(|(module_id, hit_count)| {
107                let fire_rate = if visited == 0 {
108                    0.0
109                } else {
110                    hit_count as f32 / visited as f32
111                };
112                SignalHealthEntry {
113                    module_id,
114                    fire_rate,
115                    warn: fire_rate > warn_threshold,
116                }
117            })
118            .collect();
119        Ok(Response::new(RepoSignalHealthReport {
120            entries,
121            window_states: visited,
122        }))
123    }
124
125    async fn subscribe_signal_updates(
126        &self,
127        _request: Request<SubscribeSignalUpdatesRequest>,
128    ) -> Result<Response<Self::SubscribeSignalUpdatesStream>, Status> {
129        // W2 lands the contract; live event broadcasting wires up in R3 once
130        // the signal registry can recompute on capture. For now we open a
131        // channel that closes immediately — clients see EOF and reconnect
132        // when the producer becomes available.
133        let (_tx, rx) = tokio::sync::mpsc::channel::<Result<SignalUpdate, Status>>(1);
134        Ok(Response::new(Box::pin(ReceiverStream::new(rx))))
135    }
136}
137
/// Number of states `get_repo_signal_health` examines when the request
/// leaves `window_states` at zero.
const DEFAULT_HEALTH_WINDOW: usize = 200;
/// Upper bound applied to a caller-supplied `window_states`.
const MAX_HEALTH_WINDOW: u32 = 5000;
140
141fn load_signals(repo: &Repository, state: &State) -> Result<Vec<RiskSignal>, Status> {
142    let Some(hash) = state.risk_signals else {
143        return Ok(Vec::new());
144    };
145    let blob = repo
146        .store()
147        .get_blob(&hash)
148        .map_err(to_status)?
149        .ok_or_else(|| {
150            Status::data_loss(format!(
151                "risk_signals blob {hash} referenced by state {} is missing",
152                state.change_id
153            ))
154        })?;
155    let parsed = RiskSignalBlob::decode(blob.content()).map_err(|err| {
156        Status::internal(format!(
157            "failed to decode risk signals on state {}: {err}",
158            state.change_id
159        ))
160    })?;
161    Ok(parsed.signals)
162}
163
164fn walk_recent_states(repo: &Repository, window: usize) -> objects::error::Result<Vec<State>> {
165    let mut out = Vec::new();
166    let mut cursor = repo.head()?;
167    while let Some(id) = cursor {
168        if out.len() >= window {
169            break;
170        }
171        let Some(state) = repo.store().get_state(&id)? else {
172            break;
173        };
174        let parent = state.parents.first().copied();
175        out.push(state);
176        cursor = parent;
177    }
178    Ok(out)
179}
180
181fn signal_to_proto(sig: &RiskSignal, visibility: &str) -> ProtoRiskSignal {
182    let (start_line, end_line) = sig.anchor.line_range.unwrap_or((0, 0));
183    ProtoRiskSignal {
184        kind: sig.kind.as_str().to_string(),
185        anchor: Some(ProtoSignalAnchor {
186            file: sig.anchor.file.clone(),
187            symbol: sig.anchor.symbol.clone().unwrap_or_default(),
188            start_line,
189            end_line,
190        }),
191        reason: sig.reason.clone(),
192        producer_module: sig.producer.module.clone(),
193        producer_version: sig.producer.version,
194        computed_at: Some(prost_types::Timestamp {
195            seconds: sig.computed_at,
196            nanos: 0,
197        }),
198        visibility: visibility.to_string(),
199    }
200}
201
202// Small helper kept private; exported via PathSymbolRef wherever needed.
203#[allow(dead_code)]
204fn make_path_symbol(file: &str, symbol: &str) -> PathSymbolRef {
205    PathSymbolRef {
206        file: file.to_string(),
207        symbol: symbol.to_string(),
208    }
209}
210
#[cfg(test)]
mod tests {
    use objects::object::{Attribution, Blob, Principal, ProducerId, RiskSignalKind, SignalAnchor};
    use tempfile::TempDir;

    use super::*;

    /// Initializes a default repository inside a fresh temp directory. The
    /// `TempDir` is returned so the caller keeps the directory alive.
    fn fresh_repo() -> (TempDir, Repository) {
        let dir = TempDir::new().unwrap();
        let repo = Repository::init_default(dir.path()).unwrap();
        (dir, repo)
    }

    /// Wraps `repo` in a `LocalSignalService` with its own dedup store.
    fn local_service(repo: Repository) -> LocalSignalService {
        use std::sync::Arc;

        let dedup = repo::operation_dedup::OperationDedupStore::open(repo.heddle_dir()).unwrap();
        let inner = GrpcLocalService::new(Arc::new(repo), Arc::new(dedup));
        LocalSignalService::new(inner)
    }

    /// Human attribution used for every snapshot taken by these tests.
    fn alice() -> Attribution {
        Attribution::human(Principal::new("Alice", "alice@example.com"))
    }

    /// Builds a `ComputeStateSignals` request that targets `state_id`.
    fn compute_request(state_id: Vec<u8>) -> Request<ComputeStateSignalsRequest> {
        Request::new(ComputeStateSignalsRequest {
            repo_path: String::new(),
            state_id,
            prior_state_id: Vec::new(),
        })
    }

    /// Takes a snapshot, stores `signals` as a blob, and rewrites the
    /// snapshot's state so that it references that blob.
    fn snapshot_with_signals(repo: &Repository, signals: Vec<RiskSignal>) -> ChangeId {
        let snapshot = repo
            .snapshot_with_attribution(Some("test snapshot".to_string()), None, alice())
            .unwrap();
        let encoded = RiskSignalBlob::new(signals).encode().unwrap();
        let hash = repo.store().put_blob(&Blob::new(encoded)).unwrap();
        let state = repo
            .store()
            .get_state(&snapshot.change_id)
            .unwrap()
            .unwrap();
        repo.store()
            .put_state(&state.with_risk_signals(hash))
            .unwrap();
        snapshot.change_id
    }

    /// A signal of the given kind anchored at `src/lib.rs` / `foo`.
    fn sample_signal(kind: RiskSignalKind, reason: &str) -> RiskSignal {
        RiskSignal {
            kind,
            anchor: SignalAnchor::symbol("src/lib.rs", "foo"),
            reason: reason.to_string(),
            producer: ProducerId::new("novelty.tree_sitter", 1),
            computed_at: 1_700_000_000,
            computed_against: None,
        }
    }

    #[tokio::test]
    async fn compute_state_signals_returns_persisted_signals() {
        let (_dir, repo) = fresh_repo();
        let persisted = vec![sample_signal(
            RiskSignalKind::Novelty,
            "novel control flow shape",
        )];
        let state_id = snapshot_with_signals(&repo, persisted);
        let svc = local_service(repo);

        let response = svc
            .compute_state_signals(compute_request(state_id.as_bytes().to_vec()))
            .await
            .unwrap();

        let signals = response.into_inner().signals;
        assert_eq!(signals.len(), 1);
        assert_eq!(signals[0].kind, "novelty");
        assert_eq!(signals[0].reason, "novel control flow shape");
        assert_eq!(signals[0].visibility, "visible");
    }

    #[tokio::test]
    async fn compute_state_signals_returns_empty_when_state_has_no_signals() {
        let (_dir, repo) = fresh_repo();
        let snap = repo
            .snapshot_with_attribution(Some("plain".to_string()), None, alice())
            .unwrap();
        let svc = local_service(repo);

        let response = svc
            .compute_state_signals(compute_request(snap.change_id.as_bytes().to_vec()))
            .await
            .unwrap();

        assert!(response.into_inner().signals.is_empty());
    }

    #[tokio::test]
    async fn invalid_state_id_returns_invalid_argument() {
        let (_dir, repo) = fresh_repo();
        let svc = local_service(repo);

        let err = svc
            .compute_state_signals(compute_request(b"not-a-change-id".to_vec()))
            .await
            .unwrap_err();

        assert_eq!(err.code(), tonic::Code::InvalidArgument);
    }

    #[tokio::test]
    async fn signal_health_groups_by_module_id() {
        let (_dir, repo) = fresh_repo();
        let novelty = sample_signal(RiskSignalKind::Novelty, "novel");
        snapshot_with_signals(&repo, vec![novelty]);
        let svc = local_service(repo);

        let request = Request::new(GetRepoSignalHealthRequest {
            repo_path: String::new(),
            window_states: 50,
        });
        let report = svc
            .get_repo_signal_health(request)
            .await
            .unwrap()
            .into_inner();

        assert!(report.window_states >= 1);
        let entry = report
            .entries
            .iter()
            .find(|e| e.module_id == "novelty.tree_sitter")
            .expect("novelty entry present");
        assert!(entry.fire_rate > 0.0 && entry.fire_rate <= 1.0);
    }
}