Skip to main content

dk_agent_sdk/
session.rs

1use dk_protocol::agent_service_client::AgentServiceClient;
2use dk_protocol::{
3    merge_response, Change as ProtoChange, ChangeType, ContextDepth, ContextRequest, MergeRequest,
4    SubmitRequest, VerifyRequest, WatchRequest,
5};
6use tonic::transport::Channel;
7use tokio_stream::StreamExt;
8
9use crate::error::Result;
10use crate::types::*;
11
12/// A stateful agent session bound to a changeset on the server.
13///
14/// Obtained from [`crate::AgentClient::init`].  All operations (context,
15/// submit, verify, merge, watch) are scoped to this session's changeset.
16pub struct Session {
17    client: AgentServiceClient<Channel>,
18    /// The server-assigned session identifier.
19    pub session_id: String,
20    /// The changeset created by the CONNECT handshake.
21    pub changeset_id: String,
22    /// The codebase version at the time of connection.
23    pub codebase_version: String,
24}
25
26impl Session {
27    pub(crate) fn new(client: AgentServiceClient<Channel>, result: ConnectResult) -> Self {
28        Self {
29            client,
30            session_id: result.session_id,
31            changeset_id: result.changeset_id,
32            codebase_version: result.codebase_version,
33        }
34    }
35
36    /// Query the semantic code graph for symbols matching `query`.
37    pub async fn context(
38        &mut self,
39        query: &str,
40        depth: Depth,
41        max_tokens: u32,
42    ) -> Result<ContextResult> {
43        let proto_depth = match depth {
44            Depth::Signatures => ContextDepth::Signatures as i32,
45            Depth::Full => ContextDepth::Full as i32,
46            Depth::CallGraph => ContextDepth::CallGraph as i32,
47        };
48
49        let resp = self
50            .client
51            .context(ContextRequest {
52                session_id: self.session_id.clone(),
53                query: query.to_string(),
54                depth: proto_depth,
55                include_tests: false,
56                include_dependencies: false,
57                max_tokens,
58            })
59            .await?
60            .into_inner();
61
62        Ok(ContextResult {
63            symbols: resp.symbols,
64            call_graph: resp.call_graph,
65            dependencies: resp.dependencies,
66            estimated_tokens: resp.estimated_tokens,
67        })
68    }
69
70    /// Submit a batch of code changes to the current changeset.
71    pub async fn submit(&mut self, changes: Vec<Change>) -> Result<SubmitResult> {
72        let proto_changes: Vec<ProtoChange> = changes
73            .iter()
74            .map(|c| match c {
75                Change::Add { path, content } => ProtoChange {
76                    r#type: ChangeType::AddFunction as i32,
77                    symbol_name: String::new(),
78                    file_path: path.clone(),
79                    old_symbol_id: None,
80                    new_source: content.clone(),
81                    rationale: String::new(),
82                },
83                Change::Modify { path, content } => ProtoChange {
84                    r#type: ChangeType::ModifyFunction as i32,
85                    symbol_name: String::new(),
86                    file_path: path.clone(),
87                    old_symbol_id: None,
88                    new_source: content.clone(),
89                    rationale: String::new(),
90                },
91                Change::Delete { path } => ProtoChange {
92                    r#type: ChangeType::DeleteFunction as i32,
93                    symbol_name: String::new(),
94                    file_path: path.clone(),
95                    old_symbol_id: None,
96                    new_source: String::new(),
97                    rationale: String::new(),
98                },
99            })
100            .collect();
101
102        let resp = self
103            .client
104            .submit(SubmitRequest {
105                session_id: self.session_id.clone(),
106                intent: String::new(),
107                changes: proto_changes,
108                changeset_id: self.changeset_id.clone(),
109            })
110            .await?
111            .into_inner();
112
113        let status = format!("{:?}", resp.status());
114        Ok(SubmitResult {
115            changeset_id: resp.changeset_id,
116            status,
117            errors: resp.errors,
118        })
119    }
120
121    /// Trigger the verification pipeline and collect all step results.
122    pub async fn verify(&mut self) -> Result<Vec<VerifyStepResult>> {
123        let mut stream = self
124            .client
125            .verify(VerifyRequest {
126                session_id: self.session_id.clone(),
127                changeset_id: self.changeset_id.clone(),
128            })
129            .await?
130            .into_inner();
131
132        let mut results = Vec::new();
133        while let Some(step) = stream.next().await {
134            results.push(step?);
135        }
136        Ok(results)
137    }
138
139    /// Merge the current changeset into a Git commit.
140    ///
141    /// If `force` is `true`, the recency guard is bypassed (use after the
142    /// caller has acknowledged an [`MergeResult::OverwriteWarning`]).
143    pub async fn merge(&mut self, message: &str, force: bool) -> Result<MergeResult> {
144        let resp = self
145            .client
146            .merge(MergeRequest {
147                session_id: self.session_id.clone(),
148                changeset_id: self.changeset_id.clone(),
149                commit_message: message.to_string(),
150                force,
151            })
152            .await?
153            .into_inner();
154
155        match resp.result {
156            Some(merge_response::Result::Success(s)) => Ok(MergeResult::Success(s)),
157            Some(merge_response::Result::Conflict(c)) => Ok(MergeResult::Conflict(c)),
158            Some(merge_response::Result::OverwriteWarning(w)) => {
159                Ok(MergeResult::OverwriteWarning(w))
160            }
161            None => Err(tonic::Status::internal("empty merge response").into()),
162        }
163    }
164
165    /// Subscribe to repository events (other agents' changes, merges, etc.).
166    pub async fn watch(
167        &mut self,
168        filter: Filter,
169    ) -> Result<tonic::Streaming<WatchEvent>> {
170        let filter_str = match filter {
171            Filter::All => "all",
172            Filter::Symbols => "symbols",
173            Filter::Files => "files",
174        };
175
176        let stream = self
177            .client
178            .watch(WatchRequest {
179                session_id: self.session_id.clone(),
180                repo_id: String::new(),
181                filter: filter_str.to_string(),
182            })
183            .await?
184            .into_inner();
185
186        Ok(stream)
187    }
188}