1use dk_protocol::agent_service_client::AgentServiceClient;
2use dk_protocol::{
3 merge_response, Change as ProtoChange, ChangeType, ContextDepth, ContextRequest, MergeRequest,
4 SubmitRequest, VerifyRequest, WatchRequest,
5};
6use tonic::transport::Channel;
7use tokio_stream::StreamExt;
8
9use crate::error::Result;
10use crate::types::*;
11
12pub struct Session {
17 client: AgentServiceClient<Channel>,
18 pub session_id: String,
20 pub changeset_id: String,
22 pub codebase_version: String,
24}
25
26impl Session {
27 pub(crate) fn new(client: AgentServiceClient<Channel>, result: ConnectResult) -> Self {
28 Self {
29 client,
30 session_id: result.session_id,
31 changeset_id: result.changeset_id,
32 codebase_version: result.codebase_version,
33 }
34 }
35
36 pub async fn context(
38 &mut self,
39 query: &str,
40 depth: Depth,
41 max_tokens: u32,
42 ) -> Result<ContextResult> {
43 let proto_depth = match depth {
44 Depth::Signatures => ContextDepth::Signatures as i32,
45 Depth::Full => ContextDepth::Full as i32,
46 Depth::CallGraph => ContextDepth::CallGraph as i32,
47 };
48
49 let resp = self
50 .client
51 .context(ContextRequest {
52 session_id: self.session_id.clone(),
53 query: query.to_string(),
54 depth: proto_depth,
55 include_tests: false,
56 include_dependencies: false,
57 max_tokens,
58 })
59 .await?
60 .into_inner();
61
62 Ok(ContextResult {
63 symbols: resp.symbols,
64 call_graph: resp.call_graph,
65 dependencies: resp.dependencies,
66 estimated_tokens: resp.estimated_tokens,
67 })
68 }
69
70 pub async fn submit(&mut self, changes: Vec<Change>) -> Result<SubmitResult> {
72 let proto_changes: Vec<ProtoChange> = changes
73 .iter()
74 .map(|c| match c {
75 Change::Add { path, content } => ProtoChange {
76 r#type: ChangeType::AddFunction as i32,
77 symbol_name: String::new(),
78 file_path: path.clone(),
79 old_symbol_id: None,
80 new_source: content.clone(),
81 rationale: String::new(),
82 },
83 Change::Modify { path, content } => ProtoChange {
84 r#type: ChangeType::ModifyFunction as i32,
85 symbol_name: String::new(),
86 file_path: path.clone(),
87 old_symbol_id: None,
88 new_source: content.clone(),
89 rationale: String::new(),
90 },
91 Change::Delete { path } => ProtoChange {
92 r#type: ChangeType::DeleteFunction as i32,
93 symbol_name: String::new(),
94 file_path: path.clone(),
95 old_symbol_id: None,
96 new_source: String::new(),
97 rationale: String::new(),
98 },
99 })
100 .collect();
101
102 let resp = self
103 .client
104 .submit(SubmitRequest {
105 session_id: self.session_id.clone(),
106 intent: String::new(),
107 changes: proto_changes,
108 changeset_id: self.changeset_id.clone(),
109 })
110 .await?
111 .into_inner();
112
113 let status = format!("{:?}", resp.status());
114 Ok(SubmitResult {
115 changeset_id: resp.changeset_id,
116 status,
117 errors: resp.errors,
118 })
119 }
120
121 pub async fn verify(&mut self) -> Result<Vec<VerifyStepResult>> {
123 let mut stream = self
124 .client
125 .verify(VerifyRequest {
126 session_id: self.session_id.clone(),
127 changeset_id: self.changeset_id.clone(),
128 })
129 .await?
130 .into_inner();
131
132 let mut results = Vec::new();
133 while let Some(step) = stream.next().await {
134 results.push(step?);
135 }
136 Ok(results)
137 }
138
139 pub async fn merge(&mut self, message: &str, force: bool) -> Result<MergeResult> {
144 let resp = self
145 .client
146 .merge(MergeRequest {
147 session_id: self.session_id.clone(),
148 changeset_id: self.changeset_id.clone(),
149 commit_message: message.to_string(),
150 force,
151 })
152 .await?
153 .into_inner();
154
155 match resp.result {
156 Some(merge_response::Result::Success(s)) => Ok(MergeResult::Success(s)),
157 Some(merge_response::Result::Conflict(c)) => Ok(MergeResult::Conflict(c)),
158 Some(merge_response::Result::OverwriteWarning(w)) => {
159 Ok(MergeResult::OverwriteWarning(w))
160 }
161 None => Err(tonic::Status::internal("empty merge response").into()),
162 }
163 }
164
165 pub async fn watch(
167 &mut self,
168 filter: Filter,
169 ) -> Result<tonic::Streaming<WatchEvent>> {
170 let filter_str = match filter {
171 Filter::All => "all",
172 Filter::Symbols => "symbols",
173 Filter::Files => "files",
174 };
175
176 let stream = self
177 .client
178 .watch(WatchRequest {
179 session_id: self.session_id.clone(),
180 repo_id: String::new(),
181 filter: filter_str.to_string(),
182 })
183 .await?
184 .into_inner();
185
186 Ok(stream)
187 }
188}