1use std::{collections::HashMap, pin::Pin};
8
9use futures::Stream;
10use grpc::heddle::v1::{
11 ComputeStateSignalsRequest, ComputeStateSignalsResponse, GetRepoSignalHealthRequest,
12 PathSymbolRef, RepoSignalHealthReport, RiskSignal as ProtoRiskSignal,
13 SignalAnchor as ProtoSignalAnchor, SignalHealthEntry, SignalUpdate,
14 SubscribeSignalUpdatesRequest, signal_service_server::SignalService,
15};
16use objects::object::{ChangeId, RiskSignal, RiskSignalBlob, State};
17use repo::Repository;
18use tokio_stream::wrappers::ReceiverStream;
19use tonic::{Request, Response, Status};
20
21use super::{GrpcLocalService, to_status};
22
/// gRPC `SignalService` implementation that answers requests from a local
/// repository reached through a [`GrpcLocalService`].
#[derive(Clone)]
pub struct LocalSignalService {
    /// Shared local-service handle; the repository is obtained via `inner.repo()`.
    inner: GrpcLocalService,
}
27
28impl LocalSignalService {
29 pub fn new(inner: GrpcLocalService) -> Self {
30 Self { inner }
31 }
32
33 fn repo(&self) -> &Repository {
34 self.inner.repo()
35 }
36}
37
38#[tonic::async_trait]
39impl SignalService for LocalSignalService {
40 type SubscribeSignalUpdatesStream =
41 Pin<Box<dyn Stream<Item = Result<SignalUpdate, Status>> + Send>>;
42
43 async fn compute_state_signals(
44 &self,
45 request: Request<ComputeStateSignalsRequest>,
46 ) -> Result<Response<ComputeStateSignalsResponse>, Status> {
47 let req = request.into_inner();
48 if req.state_id.is_empty() {
49 return Err(Status::invalid_argument("state_id is required"));
50 }
51 let change_id = ChangeId::try_from_slice(&req.state_id)
52 .map_err(|err| Status::invalid_argument(format!("invalid state_id: {err}")))?;
53 let state = self
54 .repo()
55 .store()
56 .get_state(&change_id)
57 .map_err(to_status)?
58 .ok_or_else(|| Status::not_found(format!("state {change_id} not found")))?;
59 let signals = load_signals(self.repo(), &state)?;
64 let proto_signals = signals
65 .iter()
66 .map(|s| signal_to_proto(s, "visible"))
67 .collect();
68 Ok(Response::new(ComputeStateSignalsResponse {
69 signals: proto_signals,
70 tick_budget: 3,
71 }))
72 }
73
74 async fn get_repo_signal_health(
75 &self,
76 request: Request<GetRepoSignalHealthRequest>,
77 ) -> Result<Response<RepoSignalHealthReport>, Status> {
78 let req = request.into_inner();
79 let window = if req.window_states == 0 {
80 DEFAULT_HEALTH_WINDOW
81 } else {
82 req.window_states.min(MAX_HEALTH_WINDOW) as usize
83 };
84 let states = walk_recent_states(self.repo(), window).map_err(to_status)?;
88 let visited = states.len() as u32;
89 let mut per_module: HashMap<String, u32> = HashMap::new();
90 for state in &states {
91 let signals = load_signals(self.repo(), state)?;
92 let mut seen_modules = std::collections::HashSet::new();
96 for sig in &signals {
97 let key = sig.producer.module.clone();
98 if seen_modules.insert(key.clone()) {
99 *per_module.entry(key).or_insert(0) += 1;
100 }
101 }
102 }
103 let warn_threshold = 0.5_f32;
104 let entries = per_module
105 .into_iter()
106 .map(|(module_id, hit_count)| {
107 let fire_rate = if visited == 0 {
108 0.0
109 } else {
110 hit_count as f32 / visited as f32
111 };
112 SignalHealthEntry {
113 module_id,
114 fire_rate,
115 warn: fire_rate > warn_threshold,
116 }
117 })
118 .collect();
119 Ok(Response::new(RepoSignalHealthReport {
120 entries,
121 window_states: visited,
122 }))
123 }
124
125 async fn subscribe_signal_updates(
126 &self,
127 _request: Request<SubscribeSignalUpdatesRequest>,
128 ) -> Result<Response<Self::SubscribeSignalUpdatesStream>, Status> {
129 let (_tx, rx) = tokio::sync::mpsc::channel::<Result<SignalUpdate, Status>>(1);
134 Ok(Response::new(Box::pin(ReceiverStream::new(rx))))
135 }
136}
137
/// Number of states examined by `get_repo_signal_health` when the request
/// leaves `window_states` at `0`.
const DEFAULT_HEALTH_WINDOW: usize = 200;
/// Upper bound applied to a caller-supplied `window_states`.
const MAX_HEALTH_WINDOW: u32 = 5_000;
140
141fn load_signals(repo: &Repository, state: &State) -> Result<Vec<RiskSignal>, Status> {
142 let Some(hash) = state.risk_signals else {
143 return Ok(Vec::new());
144 };
145 let blob = repo
146 .store()
147 .get_blob(&hash)
148 .map_err(to_status)?
149 .ok_or_else(|| {
150 Status::data_loss(format!(
151 "risk_signals blob {hash} referenced by state {} is missing",
152 state.change_id
153 ))
154 })?;
155 let parsed = RiskSignalBlob::decode(blob.content()).map_err(|err| {
156 Status::internal(format!(
157 "failed to decode risk signals on state {}: {err}",
158 state.change_id
159 ))
160 })?;
161 Ok(parsed.signals)
162}
163
164fn walk_recent_states(repo: &Repository, window: usize) -> objects::error::Result<Vec<State>> {
165 let mut out = Vec::new();
166 let mut cursor = repo.head()?;
167 while let Some(id) = cursor {
168 if out.len() >= window {
169 break;
170 }
171 let Some(state) = repo.store().get_state(&id)? else {
172 break;
173 };
174 let parent = state.parents.first().copied();
175 out.push(state);
176 cursor = parent;
177 }
178 Ok(out)
179}
180
181fn signal_to_proto(sig: &RiskSignal, visibility: &str) -> ProtoRiskSignal {
182 let (start_line, end_line) = sig.anchor.line_range.unwrap_or((0, 0));
183 ProtoRiskSignal {
184 kind: sig.kind.as_str().to_string(),
185 anchor: Some(ProtoSignalAnchor {
186 file: sig.anchor.file.clone(),
187 symbol: sig.anchor.symbol.clone().unwrap_or_default(),
188 start_line,
189 end_line,
190 }),
191 reason: sig.reason.clone(),
192 producer_module: sig.producer.module.clone(),
193 producer_version: sig.producer.version,
194 computed_at: Some(prost_types::Timestamp {
195 seconds: sig.computed_at,
196 nanos: 0,
197 }),
198 visibility: visibility.to_string(),
199 }
200}
201
202#[allow(dead_code)]
204fn make_path_symbol(file: &str, symbol: &str) -> PathSymbolRef {
205 PathSymbolRef {
206 file: file.to_string(),
207 symbol: symbol.to_string(),
208 }
209}
210
#[cfg(test)]
mod tests {
    use objects::object::{Attribution, Blob, Principal, ProducerId, RiskSignalKind, SignalAnchor};
    use tempfile::TempDir;

    use super::*;

    /// Initialises a default repository inside a new temporary directory.
    ///
    /// The `TempDir` is handed back so callers can keep it alive for the
    /// duration of the test.
    fn fresh_repo() -> (TempDir, Repository) {
        let dir = TempDir::new().unwrap();
        let repo = Repository::init_default(dir.path()).unwrap();
        (dir, repo)
    }

    /// Builds the service under test on top of `repo`.
    fn local_service(repo: Repository) -> LocalSignalService {
        let dedup_store =
            repo::operation_dedup::OperationDedupStore::open(repo.heddle_dir()).unwrap();
        let dedup = std::sync::Arc::new(dedup_store);
        let inner = GrpcLocalService::new(std::sync::Arc::new(repo), dedup);
        LocalSignalService::new(inner)
    }

    /// Takes a snapshot, stores `signals` as a blob, and rewrites the snapshot's
    /// state so it references that blob. Returns the snapshot's change id.
    fn snapshot_with_signals(repo: &Repository, signals: Vec<RiskSignal>) -> ChangeId {
        let author = Attribution::human(Principal::new("Alice", "alice@example.com"));
        let snapshot = repo
            .snapshot_with_attribution(Some("test snapshot".to_string()), None, author)
            .unwrap();

        let encoded = RiskSignalBlob::new(signals).encode().unwrap();
        let blob_hash = repo.store().put_blob(&Blob::new(encoded)).unwrap();

        let stored = repo
            .store()
            .get_state(&snapshot.change_id)
            .unwrap()
            .unwrap();
        let with_signals = stored.with_risk_signals(blob_hash);
        repo.store().put_state(&with_signals).unwrap();

        snapshot.change_id
    }

    /// A signal anchored at `src/lib.rs::foo`, produced by `novelty.tree_sitter` v1.
    fn sample_signal(kind: RiskSignalKind, reason: &str) -> RiskSignal {
        RiskSignal {
            kind,
            anchor: SignalAnchor::symbol("src/lib.rs", "foo"),
            reason: reason.to_owned(),
            producer: ProducerId::new("novelty.tree_sitter", 1),
            computed_at: 1_700_000_000,
            computed_against: None,
        }
    }

    /// Wraps raw `state_id` bytes in a `ComputeStateSignalsRequest` with every
    /// other field left empty.
    fn signals_request(state_id: Vec<u8>) -> Request<ComputeStateSignalsRequest> {
        Request::new(ComputeStateSignalsRequest {
            repo_path: String::new(),
            state_id,
            prior_state_id: Vec::new(),
        })
    }

    #[tokio::test]
    async fn compute_state_signals_returns_persisted_signals() {
        let (_tmp, repo) = fresh_repo();
        let persisted = sample_signal(RiskSignalKind::Novelty, "novel control flow shape");
        let state_id = snapshot_with_signals(&repo, vec![persisted]);
        let service = local_service(repo);

        let response = service
            .compute_state_signals(signals_request(state_id.as_bytes().to_vec()))
            .await
            .unwrap();

        let signals = response.into_inner().signals;
        assert_eq!(signals.len(), 1);
        assert_eq!(signals[0].kind, "novelty");
        assert_eq!(signals[0].reason, "novel control flow shape");
        assert_eq!(signals[0].visibility, "visible");
    }

    #[tokio::test]
    async fn compute_state_signals_returns_empty_when_state_has_no_signals() {
        let (_tmp, repo) = fresh_repo();
        let author = Attribution::human(Principal::new("Alice", "alice@example.com"));
        let plain = repo
            .snapshot_with_attribution(Some("plain".to_string()), None, author)
            .unwrap();
        let service = local_service(repo);

        let response = service
            .compute_state_signals(signals_request(plain.change_id.as_bytes().to_vec()))
            .await
            .unwrap();

        assert!(response.into_inner().signals.is_empty());
    }

    #[tokio::test]
    async fn invalid_state_id_returns_invalid_argument() {
        let (_tmp, repo) = fresh_repo();
        let service = local_service(repo);

        let status = service
            .compute_state_signals(signals_request("not-a-change-id".into()))
            .await
            .unwrap_err();

        assert_eq!(status.code(), tonic::Code::InvalidArgument);
    }

    #[tokio::test]
    async fn signal_health_groups_by_module_id() {
        let (_tmp, repo) = fresh_repo();
        snapshot_with_signals(&repo, vec![sample_signal(RiskSignalKind::Novelty, "novel")]);
        let service = local_service(repo);

        let request = Request::new(GetRepoSignalHealthRequest {
            repo_path: String::new(),
            window_states: 50,
        });
        let report = service
            .get_repo_signal_health(request)
            .await
            .unwrap()
            .into_inner();

        assert!(report.window_states >= 1);
        let novelty_entry = report
            .entries
            .iter()
            .find(|entry| entry.module_id == "novelty.tree_sitter")
            .expect("novelty entry present");
        assert!(novelty_entry.fire_rate > 0.0 && novelty_entry.fire_rate <= 1.0);
    }
}